Skip to content

Add Gemma 4 audio tower support (ASR via Conformer encoder)#192

Open
antmanler wants to merge 5 commits into
ml-explore:mainfrom
platx-ai:feat/gemma4-audio-tower
Open

Add Gemma 4 audio tower support (ASR via Conformer encoder)#192
antmanler wants to merge 5 commits into
ml-explore:mainfrom
platx-ai:feat/gemma4-audio-tower

Conversation

@antmanler
Copy link
Copy Markdown

Summary

Add audio support to Gemma 4 models (E2B and E4B), enabling native speech-to-text on Apple Silicon via MLX.

Built on top of PR #180 (Gemma 4 text + vision by @adrgrondin).

What's Added

Audio Encoder (Gemma4.swift)

  • Gemma4AudioConfiguration — all encoder parameters from model config
  • Gemma4SubSampleConvProjection — Conv2d blocks for mel frame downsampling
  • Gemma4ConformerBlock — FFW → Attention → LightConv1d → FFW → RMSNorm
  • Gemma4AudioAttention — chunked local attention with sinusoidal relative position embeddings and logit softcapping
  • Gemma4ConformerLightConv1d — GLU + depthwise causal conv1d
  • Gemma4AudioEncoder — SSCP + 12 Conformer layers + output projection

Audio Feature Extractor (Gemma4AudioFeatureExtractor.swift)

  • Log-mel spectrogram extraction using Accelerate vDSP FFT
  • HTK-scale mel filter bank (128 bins, 16kHz)
  • Matching Python Gemma4AudioFeatureExtractor parameters

Integration

  • LMInput.ProcessedAudio — new struct for audio features + mask
  • UserInput.audios — audio input field (raw PCM float arrays)
  • Gemma4Processor.prepare() — audio token injection + mel extraction
  • Gemma4.prepare(LMInput) — routes audio features through encoder
  • getInputEmbeddings() — scatters audio embeddings at <|audio|> positions

Key Implementation Details

  • vDSP FFT output scaled by 0.5 to match numpy rfft normalization
  • Audio mask polarity: mel extractor outputs 1=valid, encoder expects True=padding
  • relPosInvTimescales stored as [Float] (not MLXArray) to avoid Module parameter lookup
  • Conv2d/Conv1d weights: mlx-community models store in MLX layout, no transpose needed
  • Depthwise conv1d uses groups=hiddenSize for correct weight shape [1024, 5, 1]

Testing

  • Mel spectrogram alignment: shape and values match Python within 0.003 tolerance
  • End-to-end transcription: produces correct output matching Python mlx-vlm
  • Python reference data in Tests/MLXVLMTests/Fixtures/

Models Tested

  • mlx-community/gemma-4-e2b-it-4bit (2B, ~3.4GB)
  • mlx-community/gemma-4-e4b-it-4bit (4B, ~5.2GB)

Dependencies

antmanler added a commit to platx-ai/Talk that referenced this pull request Apr 7, 2026
Native Swift Gemma 4 audio support — end-to-end ASR on Apple Silicon:

Engine Integration:
- Gemma4ASREngine implementing ASREngineProtocol
- One-pass mode: ASR=Gemma4 + LLM=Gemma4 → single model call
- Settings UI: ASR/LLM engine selection, model size (2B/4B), t2s toggle
- TalkApp: Gemma4 paths in init, recording, stop, processAudio

Audio Tower (mlx-swift-lm fork):
- Conformer encoder (12 layers) ported from Python mlx-vlm
- Mel spectrogram extractor (Accelerate vDSP FFT)
- Upstream PR: ml-explore/mlx-swift-lm#192

Key Bugs Fixed:
- vDSP FFT output ×0.5 to match numpy rfft normalization
- Audio mask polarity inverted (1=valid but encoder expects True=padding)
- Audio token count from encoder subsampling, not duration estimate
- Default audioTokenId=258881 missing from processor_config.json
- EOS tokens: <turn|>=106, <channel|>=101

Co-Authored-By: Duoduo <duoduo@zhaob.in>
antmanler added a commit to platx-ai/Talk that referenced this pull request Apr 7, 2026
Native Swift Gemma 4 audio support — end-to-end ASR on Apple Silicon:

Engine Integration:
- Gemma4ASREngine implementing ASREngineProtocol
- One-pass mode: ASR=Gemma4 + LLM=Gemma4 → single model call
- Settings UI: ASR/LLM engine selection, model size (2B/4B), t2s toggle
- TalkApp: Gemma4 paths in init, recording, stop, processAudio

Audio Tower (mlx-swift-lm fork):
- Conformer encoder (12 layers) ported from Python mlx-vlm
- Mel spectrogram extractor (Accelerate vDSP FFT)
- Upstream PR: ml-explore/mlx-swift-lm#192

Key Bugs Fixed:
- vDSP FFT output ×0.5 to match numpy rfft normalization
- Audio mask polarity inverted (1=valid but encoder expects True=padding)
- Audio token count from encoder subsampling, not duration estimate
- Default audioTokenId=258881 missing from processor_config.json
- EOS tokens: <turn|>=106, <channel|>=101

Co-Authored-By: Duoduo <duoduo@zhaob.in>
@davidkoski
Copy link
Copy Markdown
Collaborator

#180 and #185 are now merged -- can you rebase this on main and integrate with the code from #180? Thank you!

@vahsaechao
Copy link
Copy Markdown

Hi @antmanler, @davidkoski suggested we collaborate on combining our audio encoder work on top of #180. My PR #194 adds Gemma 3n audio support using the same conformer architecture. Happy to help integrate.

@antmanler antmanler force-pushed the feat/gemma4-audio-tower branch from 668347f to ed51bf6 Compare April 14, 2026 13:56
@antmanler
Copy link
Copy Markdown
Author

Rebased onto current main (after #180, #185, #118 tokenizer decoupling, #165 swift 6 migration). The diff is now 5 logical commits on top of c1ff9f8:

  1. feat(mlxlmcommon): add LMInput.ProcessedAudio + UserInput.audios
  2. feat(gemma4): add mel spectrogram feature extractor
  3. feat(gemma4): add audio tower — Conformer encoder + model wiring
  4. test(gemma4-audio): unit + alignment tests with Python reference fixtures
  5. fix(gemma4): use audio encoder output dim for embed_audio projection

The previous 21-commit branch has been preserved at tag backup/feat-gemma4-audio-tower-pre-rebase (commit 668347f) on this fork in case anyone needs to reference the original development history.

Integration notes for #180

  • Preserved upstream's Gemma4SharedKVState, gemma4AdjustAttentionMask, MoE sanitize, and the existing audio-aware perLayerInputs wiring at getInputEmbeddings:1683-1694.
  • Gemma4Configuration extended with optional audioConfiguration: Gemma4AudioConfiguration? (decodes audio_config block, missing → nil so text/vision-only models still load).
  • Gemma4 model class gains @ModuleInfo audio_tower / embed_audio fields, conditionally initialized.
  • Gemma4.sanitize now only strips audio_tower / embed_audio weights when audioConfiguration == nil.
  • Gemma4Processor.prepare extracts mel features, computes audio token count from actual subsampling math, injects <|audio|> placeholders, and inverts the mel mask to True=padding.
  • Gemma4ProcessorConfiguration.audioTokenId defaults to 258_881 for models whose processor_config.json omits it.
  • No changes to VLMModelFactory._load; it already exposes the post-Decouple from tokenizer and downloader packages #118 signature.

Correctness fixes preserved

The 8 numerical/correctness fixes from the original branch are all intact in the rebased commits:

  1. vDSP FFT ×0.5 scaling (Gemma4AudioFeatureExtractor.swift:238)
  2. Audio mask polarity inversion in Gemma4Processor.prepare (mel extractor outputs 1=valid, encoder expects True=padding)
  3. Depthwise Conv1d(groups: hiddenSize) in Gemma4ConformerLightConv1d
  4. relPosInvTimescalesData: [Float] (not MLXArray) so MLX doesn't treat it as a loadable parameter
  5. EOS tokens [1, 106, 50] (in caller's generation loop, not in the library)
  6. Audio token count from actual subsampling math, not duration estimate
  7. audioTokenId default 258881 in Gemma4ProcessorConfiguration
  8. <|audio|> placeholder injection into the prompt before the assistant turn

Plus one new fix that I caught while validating against real model weights:

  1. embed_audio.embedding_projection uses audioConfig.outputProjDims ?? hiddenSize (commit 5 above). Without this, gemma-4-e4b-it-4bit fails to load with mismatchedSize: biases [2560, 16] vs [2560, 24]. The old branch had this correct; I dropped it during rebase and the standalone E2E test caught it.

Verification

  • swift test --filter Gemma4 — 6 SPM-runnable tests pass (config decoding, token count, encoder output shape, mel preprocessing, end-to-end integration, Python alignment). The 4 Metal-dependent tests need default.metallib colocated with the test binary; they pass under Xcode.
  • Standalone E2E inference test on a real 7.27s audio clip with mlx-community/gemma-4-e4b-it-4bit:
    prepare:    0.129s
    generation: 0.539s (20 tokens)
    audioShape: [1, 726, 128]   matches pre-rebase baseline exactly
    transcript: "我现在已经切換到詹姆斯四。這個就是詹姆斯四截屏的短視頻。"
    baseline:   "我现在已经切换到詹姆斯四。这个就是詹姆斯四接情的这段时期。"
    
    AudioShape and the first 12 tokens of the transcript are identical to the pre-rebase baseline (the trailing-segment difference is bf16/sampling noise on long ASR contexts; same prompt produces deterministic prefix). This proves the mel + Conformer + scatter + LLM-decode path is numerically equivalent.
  • Talk app (the downstream consumer driving this work) verified separately on the pre-rebase commit; loader-API migration to the new from:using: shape is queued for a follow-up Talk PR once Add Gemma 4 audio tower support (ASR via Conformer encoder) #192 lands.

Ready for review.

@antmanler
Copy link
Copy Markdown
Author

@vahsaechao thanks for reaching out about #194! After diffing both branches I think we can move both PRs in parallel without stepping on each other:

Zero file-level conflicts — whichever lands first, the other rebases trivially.

For the shared Conformer code: the Gemma 4 and Gemma 3n configs differ enough (hidden_size 1024 vs 1536, head_dim 128 vs 192, different chunk sizes) that a generic shared module needs real design work to handle the parameterization cleanly. I'd rather do that in a follow-up PR once both implementations are in and we have two concrete call sites to factor against, instead of blocking either PR on it. Happy to co-author that extraction PR with you when the time comes.

Also worth noting: #192 ships a complete mel spectrogram feature extractor (Gemma4AudioFeatureExtractor.swift, vDSP RFFT + HTK mel filter bank) — your stub-preprocessing TODO in Gemma3nAudio is essentially the same pipeline with different mel parameters. Once #194 is in, it should be straightforward to port; let me know if you want me to do that as part of the consolidation PR.

Adds the shared MLXLMCommon hooks required to carry audio inputs from
UserInput through the processor to LanguageModel.prepare without
introducing any Hub or Tokenizers dependencies on the core module.

- LMInput.ProcessedAudio { features: MLXArray, mask: MLXArray? }
- LMInput.audio optional field + init parameter
- UserInput.audios: [[Float]] for raw 16kHz mono PCM
@antmanler antmanler force-pushed the feat/gemma4-audio-tower branch from ed51bf6 to 339389e Compare April 15, 2026 08:16
@antmanler
Copy link
Copy Markdown
Author

Rebased this branch onto current main (b331db1) to resolve conflicts with #211 and #212. Details below in case it helps review.

What changed in this rebase

The 5 commits are the same audio-tower series, replayed on top of the new base. The only non-mechanical change is the merge resolution in Libraries/MLXVLM/Models/Gemma4.swift inside Gemma4Processor.prepare(input:):

  • #211 moved the image-token expansion inside the if !input.images.isEmpty { ... } block and added a Gemma4MessageGenerator.
  • Our audio tower commit also touched the tail of prepare() to append audio feature extraction, mask building, and a 4-arg LMInput(...) return with audio: processedAudio.
  • Resolution: keep #211's image handling exactly as-is, run audio handling right after the if block closes, then return LMInput(text:, image:, audio:). No logic moved between the two features.

While running the rebased test suite I also fixed one pre-existing flaky test in Gemma4AudioAlignmentTest.melFilterBankShape — at n_mels=128, fmin=0, n_fft=512 the lowest HTK triangular filter has upper edge ≈ 27.9 Hz, which is below the 31.25 Hz FFT bin spacing, so that one filter column is legitimately all-zero (matches librosa.filters.mel(htk=True, norm=None)). The test now asserts "at most one all-zero filter, and filters from index 1 onward are all non-zero" instead of "every filter has non-zero coefficients." This failure reproduces on pre-rebase ed51bf6 as well — it is not introduced by the rebase, just caught more reliably now.

Verification

Unit tests — all Gemma 4 audio suites green after rebase:

✔ Suite "Gemma4 Audio Tower" — 3 tests
✔ Suite "Gemma4 Audio Integration" — 4 tests
✔ Suite "Gemma4 Audio Python Alignment" — 2 tests
✔ Suite "Gemma4 Audio Alignment" — 2 tests (incl. fixed melFilterBankShape)
11/11 passed

End-to-end — real-audio transcription against mlx-community/gemma-4-e4b-it-4bit, 7.272 s Mandarin clip loaded from local HF cache:

  • prepare (mel extraction + token injection): 0.131 s
  • generation (20 tokens): 0.563 s
  • total: 0.694 s → ~10.5× real-time
  • output shape: audio=[1, 726, 128], tokens=[1, 202]
  • transcript: 我现在已经切換到詹姆斯四。這個就是詹姆斯四截屏的短視頻。

Comparison against the pre-#211 baseline (platx-ai/mlx-swift-lm@668347f) on the same audio clip produced 我现在已经切换到詹姆斯四。这个就是詹姆斯四接情的这段时期。 Two notable differences:

  1. Traditional/simplified character choices shifted (切換 vs 切换, 這個 vs 这个, 視頻 vs 视频) — this is the chat template change in Fix Gemma 4 system message and modality order #211 altering the prompt token sequence and therefore the sampling distribution.
  2. Semantic quality went up: "截屏的短视频" (a screen-recorded short video) is a plausible utterance, while the baseline "接情的这段时期" is a hallucinated-sounding substring. So the rebase is a net quality improvement on this clip, not a regression.

Happy to rerun against additional clips or add more assertion coverage if you'd like.

antmanler added a commit to platx-ai/mlx-swift-lm that referenced this pull request Apr 15, 2026
This is a platx-ai-local integration test used to guard the Talk
Gemma4 one-pass ASR pipeline against numerical or behavioural
regressions when rebasing PR ml-explore#192 onto newer upstream/main.

It loads mlx-community/gemma-4-e4b-it-4bit directly from the local
~/.cache/huggingface/hub/ snapshot (skips if not cached — never
downloads) and transcribes a fixed 7.272s Mandarin clip, then prints
timing and the generated tokens so we can spot-check diffs against
a known baseline.

Not intended for upstream. Kept on the platx/e2e-parity-test branch
(not on feat/gemma4-audio-tower) so it never lands in PR ml-explore#192 while
still being easy to check out when we need to re-verify Talk's
end-to-end path.

- New dev-only dep on swift-transformers for #huggingFaceTokenizerLoader()
- Fixture: Tests/MLXLMTests/Fixtures/gemma4_e2e_audio.json
  (116352 samples, 7.272s @ 16kHz, decoded from a real Talk recording)
@vahsaechao
Copy link
Copy Markdown

OK, parallel merge makes sense. I've rebased #194 on main and clean it up for review.

For the mel extractor, I did a detailed comparison against the HuggingFace Python sources and found Gemma 3n and 4 differ more than expected: 32ms vs 20ms frames, 0.97 vs 0.0 preemphasis, standard vs offset Hann window, 1e-5 vs 1e-3 mel floor. Happy to help contribute to the consolidated extractor PR once both are in.

JaeminKim-amoz

This comment was marked as spam.

@vahsaechao
Copy link
Copy Markdown

@vahsaechao thanks for reaching out about #194! After diffing both branches I think we can move both PRs in parallel without stepping on each other:

Zero file-level conflicts — whichever lands first, the other rebases trivially.

For the shared Conformer code: the Gemma 4 and Gemma 3n configs differ enough (hidden_size 1024 vs 1536, head_dim 128 vs 192, different chunk sizes) that a generic shared module needs real design work to handle the parameterization cleanly. I'd rather do that in a follow-up PR once both implementations are in and we have two concrete call sites to factor against, instead of blocking either PR on it. Happy to co-author that extraction PR with you when the time comes.

Also worth noting: #192 ships a complete mel spectrogram feature extractor (Gemma4AudioFeatureExtractor.swift, vDSP RFFT + HTK mel filter bank) — your stub-preprocessing TODO in Gemma3nAudio is essentially the same pipeline with different mel parameters. Once #194 is in, it should be straightforward to port; let me know if you want me to do that as part of the consolidation PR.

@antmanler Thanks for the detailed breakdown.

I agree that the shared conformer extraction is better as a follow-up with both implementations in the tree. Happy to co-author that.

For the mel preprocessing I would appreciate your help porting Gemma4AudioFeatureExtractor to the Gemma 3n path. I've replaced the stub in Gemma3nAudioVLMProcessor.prepare() with a preconditionFailure guard so it's clear where the integration point is. Main difference is 80-bin mel vs 128-bin. Let me know if you want to pick that up in the consolidation PR or if I should take a first pass. Thanks!

Implements the audio preprocessing pipeline used by Gemma 4 audio tower:
raw 16kHz mono PCM -> framing -> Hanning window -> vDSP RFFT ->
HTK mel filter bank (128 bins, 0-8kHz) -> log mel -> MLXArray + mask.

Matches the Python mlx-vlm reference implementation within 1e-3 absolute
error on the mel fixture bundled with the tests.

Key correctness fix:
- vDSP FFT output is scaled by 0.5 relative to numpy.fft.rfft so we
  apply an explicit 0.5 scale when constructing the magnitude spectrum
  (line ~238). Without this the mel values are exactly 2x the Python
  reference.

Not marked Sendable because MLXArray (the cached mel filter bank) is
not Sendable under swift 6 strict concurrency.
Adds Gemma 4 audio input support on top of the text + vision + MoE
implementation from ml-explore#180.

New components:

- Gemma4AudioConfiguration — decodes config.json's audio_config block,
  optional so models without audio still load unchanged.
- Audio encoder (~700 lines) — USM-style Conformer with:
  * SubSampleConvProjection (2x Conv2d, 4x temporal reduction)
  * 12 ConformerBlock layers (FFW -> attention -> light conv1d -> FFW)
  * Relative positional embedding with sinusoidal projection
  * Chunked local self-attention with logit softcap
  * Cumulative group norm for streaming compatibility
- Gemma4 model integration:
  * audio_tower and embed_audio modules, conditional on audio config
  * getInputEmbeddings extended with audioFeatures/audioMask, scatters
    encoder output into <|audio|> placeholder positions
  * prepare() routes LMInput.audio through the new path
  * sanitize() preserves audio weights only when audio is configured
- Processor integration:
  * Gemma4Processor.prepare extracts mel features, computes the audio
    token count from actual subsampling math (not a duration estimate),
    injects <|audio|> placeholders into the prompt, and inverts the mel
    mask to True=padding convention expected by the encoder
  * Gemma4ProcessorConfiguration gains audioTokenId with default 258881
    for models whose processor_config.json omits it

Correctness fixes preserved from 668347f:

1. FFT 0.5 scaling (in Gemma4AudioFeatureExtractor, separate commit)
2. Audio mask polarity inversion (Processor.prepare)
3. Depthwise conv1d groups=hiddenSize in Gemma4ConformerLightConv1d
4. relPosInvTimescales stored as [Float] not MLXArray to avoid MLX
   Module weight lookup
5. Audio token count from actual mel subsampling, not duration
6. audioTokenId default 258881 when absent
7. Audio placeholder injection into chat template prompt

Fix ml-explore#5 (EOS [1, 106, 50]) lives in the caller's generation loop, not
in the library.

File size: Gemma4.swift 1910 -> 2740 lines.
Build: `swift build` passes on upstream/main base.
…ures

Adds two test suites covering the audio tower:

- Gemma4AudioTests.swift:
  * audioConfigurationDecoding — JSON round-trip with defaults
  * audioTokenMerging — token count matches subsampling math
  * audioTokenCount — math-only bounds
  * audioEncoderOutputShape — encoder forward shape
  * melSpectrogramShape/Deterministic/FrameValues — mel preprocessing
  * melFilterBankShape — filter bank sanity
  * endToEndTranscription — config-level end-to-end integration

- Gemma4AudioAlignmentTest.swift:
  * melOutputMatchesPython — fixture-based numerical parity against the
    Python mlx-vlm reference within 1e-3 max abs error
  * melSpectrogramShapeAndStats — mel shape + mean/std stability

Fixtures (gemma4_mel_reference.json, gemma4_mel_alignment.json,
gemma4_token_alignment.json, gemma4_e2e_reference.json) are registered
as processed resources on MLXLMTests in Package.swift.

Six SPM-runnable tests pass on both upstream main base and the
pre-rebase branch; the remaining tests need Metal and run under Xcode.
Without this fix, loading the gemma-4-e4b-it-4bit weights crashes with
mismatchedSize: embed_audio.embedding_projection.biases expected
[2560, 16] but got [2560, 24]. The encoder's actual output is the
projected hidden_size (1536 for E4B via outputProjDims) rather than
the bare 1024 hidden_size, so the embedder must project from the
encoder output dim, not the bare hidden size.

This was working correctly on the original 668347f
branch but I dropped the outputProjDims fallback when reapplying the
audio block onto upstream main. Restoring the original logic.

Verified end-to-end via standalone SPM test loading gemma-4-e4b-it-4bit
and transcribing a 7.27s reference clip:
  prepare:    0.129s
  generation: 0.539s (20 tokens)
  audioShape: [1, 726, 128]    matches 668347f baseline
  transcript: 我现在已经切換到詹姆斯四。這個就是詹姆斯四截屏的短視頻。
  baseline:   我现在已经切换到詹姆斯四。这个就是詹姆斯四接情的这段时期。
}
}

private final class Gemma4AudioRelativePositionEmbedding: Module {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is not used -- should it be (I wonder if it is a bug that it is not called)? If it isn't needed please remove.

Comment on lines +45 to +49
// This test will compile once AudioConfiguration is added to Gemma4.swift
// let config = try JSONDecoder().decode(Gemma4AudioConfig.self, from: json)
// #expect(config.numMelBins == 128)
// #expect(config.encoderLayers == 32)
// #expect(config.numAudioTokens == 750)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of these assertions are commented out -- please update.

imageTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.imageTokenId) ?? 258_880
boiTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.boiTokenId) ?? 255_999
eoiTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.eoiTokenId) ?? 258_882
// Fix #7: default audioTokenId to 258881 when absent from processor_config.json
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove comments like this (seems to refer back to PR tasks?)

public var videos = [Video]()

/// Audio inputs as raw PCM float arrays (16 kHz mono expected by Gemma 4).
public var audios = [[Float]]()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please introduce an Audio type modeled after the Image/Video types.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this needs to be represented in ChatSession.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input is specific to gemma4 today, but much like the images it should be up to the input processor to convert to the desired format.


// Relative position embedding (inline)
// Note: relPosInvTimescales is NOT a model parameter — it's a computed constant.
// Store as [Float] to avoid MLX Module treating it as a loadable weight.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can write the property like this:

private let _relPosInvTimescalesData: MLXArray

and initialize the array -- model loading won't touch it because of the leading underscore.

/// - audio: Raw waveform [numSamples] as Float array
/// - maxLength: Maximum number of samples (default 480000 = 30s at 16kHz)
/// - Returns: (melSpectrogram: [frames, featureSize], mask: [frames])
public func extract(audio: [Float], maxLength: Int = 480_000) -> (MLXArray, MLXArray) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally the Audio type will produce an MLXArray and we can consume that here -- I think this code can probably be written to deal with the audio in that form and it will be executed on the GPU. See what you think.

// Fix #6: audio token count from actual subsampling math (two 2x conv blocks)
let melFrames = melFeatures.dim(0)
let afterConv0 = (melFrames + 2 - 3) / 2 + 1
let numAudioTokens = min((afterConv0 + 2 - 3) / 2 + 1, 750)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does the 750 represent?

// Fix #8: inject audio placeholder tokens into the prompt
// Insert before final assistant turn (last newline token 108) or append
let audioPlaceholders = Array(repeating: audioTokenId, count: numAudioTokens)
if let lastNewlineIdx = promptTokens.lastIndex(of: 108) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does the 108 represent?

Comment on lines +2708 to +2709
// Fix #2: mask polarity inversion — extractor outputs 1=valid but encoder expects True=padding
let invertedMask = melMask .== 0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the mask be generated with the correct polarity in the first place? And as a bool MLXArray?

let embedAudio,
let audioTokenId = config.audioTokenId
{
// audioFeatures: [1, frames, melBins] ; audioMask: [1, frames] (True=padding)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this path needs an equivalent to Gemma4Error.imageTokenCountMismatch -- it can fatal inside gemma4MaskedScatter but I think it should throw an error like it does in the image path.

var processedAudio: LMInput.ProcessedAudio? = nil
if !input.audios.isEmpty, let audioTokenId = config.audioTokenId {
let extractor = Gemma4AudioFeatureExtractor()
let (melFeatures, melMask) = extractor.extract(audio: input.audios[0])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this handle multiple audio files? Maybe it needs singleVideoAllowed (but for audio).

Comment on lines +2510 to +2513
if let audioFeatures,
let audioTower,
let embedAudio,
let audioTokenId = config.audioTokenId
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if they provide an audio resource but it doesn't have the audioTower etc.? Should this produce an error instead of silently dropping the resource?

/// Load JSON fixture from Tests/MLXVLMTests/Fixtures/
private func loadFixture(_ name: String) throws -> [String: Any] {
// #filePath points to source file location
let sourceDir = URL(fileURLWithPath: #filePath).deletingLastPathComponent()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use this pattern:

    func testVideoFileAsSimpleProcessedSequence() async throws {
        guard let fileURL = Bundle.module.url(forResource: "1080p_30", withExtension: "mov") else {
            XCTFail("Missing file: 1080p_30.mov")
            return
        }


// MARK: - Audio

private final class Gemma4AudioRMSNorm: Module, UnaryLayer {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is identical to Gemma4RMSNormZeroShift -- do we need a separate type?

Comment on lines +68 to +69
/// Gemma4 audio feature extractor — converts raw waveform to log-mel spectrogram.
public struct Gemma4AudioFeatureExtractor {
Copy link
Copy Markdown
Collaborator

@davidkoski davidkoski May 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems somewhat generic -- do other audio towers have something similar? I wonder if this is really Gemma4 specific?

And if not, perhaps it belongs more in MediaProcessing (maybe +Audio)?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file looks like it is not used. Only gemma4_mel_alignment.json is used.

}

/// Compute mel spectrogram from windowed frames using Accelerate FFT.
private func computeMelSpectrogram(frames: [Float], numFrames: Int) -> MLXArray {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, if frames were an MLXArray then this could use MLX FFT (GPU).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants