Add Gemma 4 audio tower support (ASR via Conformer encoder)#192
Add Gemma 4 audio tower support (ASR via Conformer encoder)#192antmanler wants to merge 5 commits into
Conversation
Native Swift Gemma 4 audio support — end-to-end ASR on Apple Silicon: Engine Integration: - Gemma4ASREngine implementing ASREngineProtocol - One-pass mode: ASR=Gemma4 + LLM=Gemma4 → single model call - Settings UI: ASR/LLM engine selection, model size (2B/4B), t2s toggle - TalkApp: Gemma4 paths in init, recording, stop, processAudio Audio Tower (mlx-swift-lm fork): - Conformer encoder (12 layers) ported from Python mlx-vlm - Mel spectrogram extractor (Accelerate vDSP FFT) - Upstream PR: ml-explore/mlx-swift-lm#192 Key Bugs Fixed: - vDSP FFT output ×0.5 to match numpy rfft normalization - Audio mask polarity inverted (1=valid but encoder expects True=padding) - Audio token count from encoder subsampling, not duration estimate - Default audioTokenId=258881 missing from processor_config.json - EOS tokens: <turn|>=106, <channel|>=101 Co-Authored-By: Duoduo <duoduo@zhaob.in>
Native Swift Gemma 4 audio support — end-to-end ASR on Apple Silicon: Engine Integration: - Gemma4ASREngine implementing ASREngineProtocol - One-pass mode: ASR=Gemma4 + LLM=Gemma4 → single model call - Settings UI: ASR/LLM engine selection, model size (2B/4B), t2s toggle - TalkApp: Gemma4 paths in init, recording, stop, processAudio Audio Tower (mlx-swift-lm fork): - Conformer encoder (12 layers) ported from Python mlx-vlm - Mel spectrogram extractor (Accelerate vDSP FFT) - Upstream PR: ml-explore/mlx-swift-lm#192 Key Bugs Fixed: - vDSP FFT output ×0.5 to match numpy rfft normalization - Audio mask polarity inverted (1=valid but encoder expects True=padding) - Audio token count from encoder subsampling, not duration estimate - Default audioTokenId=258881 missing from processor_config.json - EOS tokens: <turn|>=106, <channel|>=101 Co-Authored-By: Duoduo <duoduo@zhaob.in>
|
Hi @antmanler, @davidkoski suggested we collaborate on combining our audio encoder work on top of #180. My PR #194 adds Gemma 3n audio support using the same conformer architecture. Happy to help integrate. |
668347f to
ed51bf6
Compare
|
Rebased onto current
The previous 21-commit branch has been preserved at tag Integration notes for #180
Correctness fixes preservedThe 8 numerical/correctness fixes from the original branch are all intact in the rebased commits:
Plus one new fix that I caught while validating against real model weights:
Verification
Ready for review. |
|
@vahsaechao thanks for reaching out about #194! After diffing both branches I think we can move both PRs in parallel without stepping on each other:
Zero file-level conflicts — whichever lands first, the other rebases trivially. For the shared Conformer code: the Gemma 4 and Gemma 3n configs differ enough (hidden_size 1024 vs 1536, head_dim 128 vs 192, different chunk sizes) that a generic shared module needs real design work to handle the parameterization cleanly. I'd rather do that in a follow-up PR once both implementations are in and we have two concrete call sites to factor against, instead of blocking either PR on it. Happy to co-author that extraction PR with you when the time comes. Also worth noting: #192 ships a complete mel spectrogram feature extractor ( |
Adds the shared MLXLMCommon hooks required to carry audio inputs from
UserInput through the processor to LanguageModel.prepare without
introducing any Hub or Tokenizers dependencies on the core module.
- LMInput.ProcessedAudio { features: MLXArray, mask: MLXArray? }
- LMInput.audio optional field + init parameter
- UserInput.audios: [[Float]] for raw 16kHz mono PCM
ed51bf6 to
339389e
Compare
|
Rebased this branch onto current What changed in this rebaseThe 5 commits are the same audio-tower series, replayed on top of the new base. The only non-mechanical change is the merge resolution in
While running the rebased test suite I also fixed one pre-existing flaky test in VerificationUnit tests — all Gemma 4 audio suites green after rebase: End-to-end — real-audio transcription against
Comparison against the pre-#211 baseline (
Happy to rerun against additional clips or add more assertion coverage if you'd like. |
This is a platx-ai-local integration test used to guard the Talk Gemma4 one-pass ASR pipeline against numerical or behavioural regressions when rebasing PR ml-explore#192 onto newer upstream/main. It loads mlx-community/gemma-4-e4b-it-4bit directly from the local ~/.cache/huggingface/hub/ snapshot (skips if not cached — never downloads) and transcribes a fixed 7.272s Mandarin clip, then prints timing and the generated tokens so we can spot-check diffs against a known baseline. Not intended for upstream. Kept on the platx/e2e-parity-test branch (not on feat/gemma4-audio-tower) so it never lands in PR ml-explore#192 while still being easy to check out when we need to re-verify Talk's end-to-end path. - New dev-only dep on swift-transformers for #huggingFaceTokenizerLoader() - Fixture: Tests/MLXLMTests/Fixtures/gemma4_e2e_audio.json (116352 samples, 7.272s @ 16kHz, decoded from a real Talk recording)
|
OK, parallel merge makes sense. I've rebased #194 on main and clean it up for review. For the mel extractor, I did a detailed comparison against the HuggingFace Python sources and found Gemma 3n and 4 differ more than expected: 32ms vs 20ms frames, 0.97 vs 0.0 preemphasis, standard vs offset Hann window, 1e-5 vs 1e-3 mel floor. Happy to help contribute to the consolidated extractor PR once both are in. |
@antmanler Thanks for the detailed breakdown. I agree that the shared conformer extraction is better as a follow-up with both implementations in the tree. Happy to co-author that. For the mel preprocessing I would appreciate your help porting |
Implements the audio preprocessing pipeline used by Gemma 4 audio tower: raw 16kHz mono PCM -> framing -> Hanning window -> vDSP RFFT -> HTK mel filter bank (128 bins, 0-8kHz) -> log mel -> MLXArray + mask. Matches the Python mlx-vlm reference implementation within 1e-3 absolute error on the mel fixture bundled with the tests. Key correctness fix: - vDSP FFT output is scaled by 0.5 relative to numpy.fft.rfft so we apply an explicit 0.5 scale when constructing the magnitude spectrum (line ~238). Without this the mel values are exactly 2x the Python reference. Not marked Sendable because MLXArray (the cached mel filter bank) is not Sendable under swift 6 strict concurrency.
Adds Gemma 4 audio input support on top of the text + vision + MoE implementation from ml-explore#180. New components: - Gemma4AudioConfiguration — decodes config.json's audio_config block, optional so models without audio still load unchanged. - Audio encoder (~700 lines) — USM-style Conformer with: * SubSampleConvProjection (2x Conv2d, 4x temporal reduction) * 12 ConformerBlock layers (FFW -> attention -> light conv1d -> FFW) * Relative positional embedding with sinusoidal projection * Chunked local self-attention with logit softcap * Cumulative group norm for streaming compatibility - Gemma4 model integration: * audio_tower and embed_audio modules, conditional on audio config * getInputEmbeddings extended with audioFeatures/audioMask, scatters encoder output into <|audio|> placeholder positions * prepare() routes LMInput.audio through the new path * sanitize() preserves audio weights only when audio is configured - Processor integration: * Gemma4Processor.prepare extracts mel features, computes the audio token count from actual subsampling math (not a duration estimate), injects <|audio|> placeholders into the prompt, and inverts the mel mask to True=padding convention expected by the encoder * Gemma4ProcessorConfiguration gains audioTokenId with default 258881 for models whose processor_config.json omits it Correctness fixes preserved from 668347f: 1. FFT 0.5 scaling (in Gemma4AudioFeatureExtractor, separate commit) 2. Audio mask polarity inversion (Processor.prepare) 3. Depthwise conv1d groups=hiddenSize in Gemma4ConformerLightConv1d 4. relPosInvTimescales stored as [Float] not MLXArray to avoid MLX Module weight lookup 5. Audio token count from actual mel subsampling, not duration 6. audioTokenId default 258881 when absent 7. Audio placeholder injection into chat template prompt Fix ml-explore#5 (EOS [1, 106, 50]) lives in the caller's generation loop, not in the library. File size: Gemma4.swift 1910 -> 2740 lines. Build: `swift build` passes on upstream/main base.
…ures
Adds two test suites covering the audio tower:
- Gemma4AudioTests.swift:
* audioConfigurationDecoding — JSON round-trip with defaults
* audioTokenMerging — token count matches subsampling math
* audioTokenCount — math-only bounds
* audioEncoderOutputShape — encoder forward shape
* melSpectrogramShape/Deterministic/FrameValues — mel preprocessing
* melFilterBankShape — filter bank sanity
* endToEndTranscription — config-level end-to-end integration
- Gemma4AudioAlignmentTest.swift:
* melOutputMatchesPython — fixture-based numerical parity against the
Python mlx-vlm reference within 1e-3 max abs error
* melSpectrogramShapeAndStats — mel shape + mean/std stability
Fixtures (gemma4_mel_reference.json, gemma4_mel_alignment.json,
gemma4_token_alignment.json, gemma4_e2e_reference.json) are registered
as processed resources on MLXLMTests in Package.swift.
Six SPM-runnable tests pass on both upstream main base and the
pre-rebase branch; the remaining tests need Metal and run under Xcode.
Without this fix, loading the gemma-4-e4b-it-4bit weights crashes with mismatchedSize: embed_audio.embedding_projection.biases expected [2560, 16] but got [2560, 24]. The encoder's actual output is the projected hidden_size (1536 for E4B via outputProjDims) rather than the bare 1024 hidden_size, so the embedder must project from the encoder output dim, not the bare hidden size. This was working correctly on the original 668347f branch but I dropped the outputProjDims fallback when reapplying the audio block onto upstream main. Restoring the original logic. Verified end-to-end via standalone SPM test loading gemma-4-e4b-it-4bit and transcribing a 7.27s reference clip: prepare: 0.129s generation: 0.539s (20 tokens) audioShape: [1, 726, 128] matches 668347f baseline transcript: 我现在已经切換到詹姆斯四。這個就是詹姆斯四截屏的短視頻。 baseline: 我现在已经切换到詹姆斯四。这个就是詹姆斯四接情的这段时期。
339389e to
d2e6490
Compare
| } | ||
| } | ||
|
|
||
| private final class Gemma4AudioRelativePositionEmbedding: Module { |
There was a problem hiding this comment.
This class is not used -- should it be (I wonder if it is a bug that it is not called)? If it isn't needed please remove.
| // This test will compile once AudioConfiguration is added to Gemma4.swift | ||
| // let config = try JSONDecoder().decode(Gemma4AudioConfig.self, from: json) | ||
| // #expect(config.numMelBins == 128) | ||
| // #expect(config.encoderLayers == 32) | ||
| // #expect(config.numAudioTokens == 750) |
There was a problem hiding this comment.
All of these assertions are commented out -- please update.
| imageTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.imageTokenId) ?? 258_880 | ||
| boiTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.boiTokenId) ?? 255_999 | ||
| eoiTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.eoiTokenId) ?? 258_882 | ||
| // Fix #7: default audioTokenId to 258881 when absent from processor_config.json |
There was a problem hiding this comment.
Please remove comments like this (seems to refer back to PR tasks?)
| public var videos = [Video]() | ||
|
|
||
| /// Audio inputs as raw PCM float arrays (16 kHz mono expected by Gemma 4). | ||
| public var audios = [[Float]]() |
There was a problem hiding this comment.
Please introduce an Audio type modeled after the Image/Video types.
There was a problem hiding this comment.
Also, this needs to be represented in ChatSession.
There was a problem hiding this comment.
The input is specific to gemma4 today, but much like the images it should be up to the input processor to convert to the desired format.
|
|
||
| // Relative position embedding (inline) | ||
| // Note: relPosInvTimescales is NOT a model parameter — it's a computed constant. | ||
| // Store as [Float] to avoid MLX Module treating it as a loadable weight. |
There was a problem hiding this comment.
You can write the property like this:
private let _relPosInvTimescalesData: MLXArrayand initialize the array -- model loading won't touch it because of the leading underscore.
| /// - audio: Raw waveform [numSamples] as Float array | ||
| /// - maxLength: Maximum number of samples (default 480000 = 30s at 16kHz) | ||
| /// - Returns: (melSpectrogram: [frames, featureSize], mask: [frames]) | ||
| public func extract(audio: [Float], maxLength: Int = 480_000) -> (MLXArray, MLXArray) { |
There was a problem hiding this comment.
Ideally the Audio type will produce an MLXArray and we can consume that here -- I think this code can probably be written to deal with the audio in that form and it will be executed on the GPU. See what you think.
| // Fix #6: audio token count from actual subsampling math (two 2x conv blocks) | ||
| let melFrames = melFeatures.dim(0) | ||
| let afterConv0 = (melFrames + 2 - 3) / 2 + 1 | ||
| let numAudioTokens = min((afterConv0 + 2 - 3) / 2 + 1, 750) |
There was a problem hiding this comment.
What does the 750 represent?
| // Fix #8: inject audio placeholder tokens into the prompt | ||
| // Insert before final assistant turn (last newline token 108) or append | ||
| let audioPlaceholders = Array(repeating: audioTokenId, count: numAudioTokens) | ||
| if let lastNewlineIdx = promptTokens.lastIndex(of: 108) { |
There was a problem hiding this comment.
What does the 108 represent?
| // Fix #2: mask polarity inversion — extractor outputs 1=valid but encoder expects True=padding | ||
| let invertedMask = melMask .== 0 |
There was a problem hiding this comment.
Can the mask be generated with the correct polarity in the first place? And as a bool MLXArray?
| let embedAudio, | ||
| let audioTokenId = config.audioTokenId | ||
| { | ||
| // audioFeatures: [1, frames, melBins] ; audioMask: [1, frames] (True=padding) |
There was a problem hiding this comment.
I think this path needs an equivalent to Gemma4Error.imageTokenCountMismatch -- it can fatal inside gemma4MaskedScatter but I think it should throw an error like it does in the image path.
| var processedAudio: LMInput.ProcessedAudio? = nil | ||
| if !input.audios.isEmpty, let audioTokenId = config.audioTokenId { | ||
| let extractor = Gemma4AudioFeatureExtractor() | ||
| let (melFeatures, melMask) = extractor.extract(audio: input.audios[0]) |
There was a problem hiding this comment.
Does this handle multiple audio files? Maybe it needs singleVideoAllowed (but for audio).
| if let audioFeatures, | ||
| let audioTower, | ||
| let embedAudio, | ||
| let audioTokenId = config.audioTokenId |
There was a problem hiding this comment.
What if they provide an audio resource but it doesn't have the audioTower etc.? Should this produce an error instead of silently dropping the resource?
| /// Load JSON fixture from Tests/MLXVLMTests/Fixtures/ | ||
| private func loadFixture(_ name: String) throws -> [String: Any] { | ||
| // #filePath points to source file location | ||
| let sourceDir = URL(fileURLWithPath: #filePath).deletingLastPathComponent() |
There was a problem hiding this comment.
Use this pattern:
func testVideoFileAsSimpleProcessedSequence() async throws {
guard let fileURL = Bundle.module.url(forResource: "1080p_30", withExtension: "mov") else {
XCTFail("Missing file: 1080p_30.mov")
return
}|
|
||
| // MARK: - Audio | ||
|
|
||
| private final class Gemma4AudioRMSNorm: Module, UnaryLayer { |
There was a problem hiding this comment.
This is identical to Gemma4RMSNormZeroShift -- do we need a separate type?
| /// Gemma4 audio feature extractor — converts raw waveform to log-mel spectrogram. | ||
| public struct Gemma4AudioFeatureExtractor { |
There was a problem hiding this comment.
This seems somewhat generic -- do other audio towers have something similar? I wonder if this is really Gemma4 specific?
And if not, perhaps it belongs more in MediaProcessing (maybe +Audio)?
There was a problem hiding this comment.
This file looks like it is not used. Only gemma4_mel_alignment.json is used.
| } | ||
|
|
||
| /// Compute mel spectrogram from windowed frames using Accelerate FFT. | ||
| private func computeMelSpectrogram(frames: [Float], numFrames: Int) -> MLXArray { |
There was a problem hiding this comment.
Again, if frames were an MLXArray then this could use MLX FFT (GPU).
Summary
Add audio support to Gemma 4 models (E2B and E4B), enabling native speech-to-text on Apple Silicon via MLX.
Built on top of PR #180 (Gemma 4 text + vision by @adrgrondin).
What's Added
Audio Encoder (
Gemma4.swift)Gemma4AudioConfiguration— all encoder parameters from model configGemma4SubSampleConvProjection— Conv2d blocks for mel frame downsamplingGemma4ConformerBlock— FFW → Attention → LightConv1d → FFW → RMSNormGemma4AudioAttention— chunked local attention with sinusoidal relative position embeddings and logit softcappingGemma4ConformerLightConv1d— GLU + depthwise causal conv1dGemma4AudioEncoder— SSCP + 12 Conformer layers + output projectionAudio Feature Extractor (
Gemma4AudioFeatureExtractor.swift)Gemma4AudioFeatureExtractorparametersIntegration
LMInput.ProcessedAudio— new struct for audio features + maskUserInput.audios— audio input field (raw PCM float arrays)Gemma4Processor.prepare()— audio token injection + mel extractionGemma4.prepare(LMInput)— routes audio features through encodergetInputEmbeddings()— scatters audio embeddings at<|audio|>positionsKey Implementation Details
relPosInvTimescalesstored as[Float](not MLXArray) to avoid Module parameter lookupgroups=hiddenSizefor correct weight shape[1024, 5, 1]Testing
Tests/MLXVLMTests/Fixtures/Models Tested
mlx-community/gemma-4-e2b-it-4bit(2B, ~3.4GB)mlx-community/gemma-4-e4b-it-4bit(4B, ~5.2GB)Dependencies