Add Gemma 4 video tower support #256

Open
john-rocky wants to merge 3 commits into ml-explore:main from john-rocky:feat/gemma4-video

john-rocky commented May 1, 2026

Summary

Routes Gemma 4 video frames through the existing `vision_tower` as a batch (same encoder as images, just with a smaller per-frame soft-token budget). Each `<|video|>` placeholder is expanded at the text level into a per-frame timestamped sequence, `mm:ss <start_of_image><|video|>×N<end_of_image>`, with frames joined by spaces, matching the HF Gemma 4 reference processor.

Faithful to the Python reference in the `mlx-vlm` Gemma 4 processor. The only Swift-side simplification is a fixed square video frame size (default 384×384) rather than an aspect-preserving resize, so the existing pooler kernel grid (24/3 = 8) divides cleanly and per-frame token counts match the prompt placeholders exactly.
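For readers who want to see the text-level expansion concretely, here is a minimal sketch in plain Swift. It is not the code in this PR: the function names, the default token strings, and the choice to truncate fractional seconds in the `mm:ss` label are all assumptions made for illustration.

```swift
import Foundation

/// Formats a frame time in seconds as `mm:ss`.
/// Truncating the fractional part is an assumption, not confirmed by the PR.
func formatTimestamp(_ seconds: Double) -> String {
    let total = max(0, Int(seconds))
    let pad: (Int) -> String = { $0 < 10 ? "0\($0)" : "\($0)" }
    return "\(pad(total / 60)):\(pad(total % 60))"
}

/// Hypothetical helper: replaces the i-th `<|video|>` marker in `prompt` with
/// one `mm:ss <boi><|video|>×N<eoi>` group per sampled frame of the i-th video,
/// joined by spaces. Markers without a matching video are left untouched.
func expandVideoPlaceholders(
    in prompt: String,
    frameTimestamps: [[Double]],          // one array of frame times per video
    tokensPerFrame: Int = 64,
    videoToken: String = "<|video|>",
    boiToken: String = "<start_of_image>",
    eoiToken: String = "<end_of_image>"
) -> String {
    // Split on the marker first so the soft tokens we insert are never re-expanded.
    let pieces = prompt.components(separatedBy: videoToken)
    var result = pieces[0]
    for (videoIndex, piece) in pieces.dropFirst().enumerated() {
        if videoIndex < frameTimestamps.count {
            let softTokens = String(repeating: videoToken, count: tokensPerFrame)
            let groups = frameTimestamps[videoIndex].map {
                "\(formatTimestamp($0)) \(boiToken)\(softTokens)\(eoiToken)"
            }
            result += groups.joined(separator: " ")
        } else {
            result += videoToken
        }
        result += piece
    }
    return result
}
```

Splitting the prompt on the marker before inserting anything avoids the pitfall of a replace loop re-matching the `<|video|>` soft tokens it has just written.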

What's added

  • `Gemma4Configuration` gains `video_token_id`, `vision_soft_tokens_per_video_frame` (default 64), and `video_frame_chunk_size` (default 4).
  • `Gemma4ProcessorConfiguration` gains `video_seq_length`, `video_frame_size`, `video_max_frames` (default 16), `video_fps` (default 2.0), `boi_token`, `eoi_token`.
  • `Gemma4VisionModel.callAsFunction` accepts an `outputLength` parameter that is forwarded to the pooler so video frames pool down to a smaller token budget than images.
  • `Gemma4.getInputEmbeddings` accepts `pixelValuesVideos` and processes them in chunks of `video_frame_chunk_size` through the vision tower with `outputLength = visionSoftTokensPerVideoFrame`, calling `eval` between chunks so per-chunk attention masks can be released before the next chunk is built. Features are concatenated and scattered at `video_token_id` positions in the prompt embedding.
  • `Gemma4Processor.prepare` samples frames at `video_fps` capped at `video_max_frames`, processes each through `CIImage.toSRGB().resampled(_:method:).normalized(_:_:)`, and rewrites the decoded prompt to splice in the per-frame timestamped soft-token sequence before re-encoding.
  • `Gemma4MessageGenerator` emits one `<|video|>` literal text marker per video so the chat template tokenizes them as plain text and the processor expands them at the text level.
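The frame sampling rule in the list above (sample at `video_fps`, cap at `video_max_frames`) could look roughly like the following. This is a sketch under assumptions: the PR does not say how frames are chosen once the cap is hit, so spreading them uniformly over the clip is a guess, and `sampleFrameTimestamps` is a hypothetical name.

```swift
/// Hypothetical helper: returns the timestamps (in seconds) of the frames to decode.
/// Below the cap, frames are spaced 1/fps apart. Above the cap, the spacing is
/// widened so `maxFrames` frames cover the whole clip (an assumption).
func sampleFrameTimestamps(
    duration: Double,
    fps: Double = 2.0,
    maxFrames: Int = 16
) -> [Double] {
    guard duration > 0, fps > 0, maxFrames > 0 else { return [] }
    // Number of frames the clip would yield at the requested rate.
    let natural = max(1, Int((duration * fps).rounded(.down)))
    let count = min(natural, maxFrames)
    let step = natural <= maxFrames ? 1.0 / fps : duration / Double(maxFrames)
    return (0..<count).map { Double($0) * step }
}
```

With the defaults, an 8 s clip yields exactly 16 frames at 0.5 s spacing, while a 60 s clip still yields 16 frames, spaced 3.75 s apart.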

Why chunking + lower default frame count

The vision tower allocates an attention mask of shape `(batch, 1, max_patches, max_patches)`. With `max_patches = 2520` the mask alone is ~13 MB per frame, so a 32-frame batch peaks at ~400 MB before any layer activations — enough to OOM `gemma4-E2B` on iPhone alongside the language model. The chunked dispatch keeps peak memory roughly constant in frame count, and dropping the default to 16 frames adds headroom without sacrificing useful temporal coverage at `fps = 2`.
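The memory figures above can be reproduced with a few lines of arithmetic. The sketch below assumes 2 bytes per mask element, which is what makes the ~13 MB per-frame figure work out; the helper names are hypothetical and no MLX calls are involved.

```swift
/// Size in bytes of a `(batch, 1, maxPatches, maxPatches)` attention mask.
/// `bytesPerElement = 2` is an assumption consistent with ~13 MB per frame.
func attentionMaskBytes(batch: Int, maxPatches: Int, bytesPerElement: Int = 2) -> Int {
    batch * maxPatches * maxPatches * bytesPerElement
}

/// Splits `count` frames into consecutive index ranges of at most `chunkSize`.
func chunkRanges(count: Int, chunkSize: Int) -> [Range<Int>] {
    guard count > 0, chunkSize > 0 else { return [] }
    return stride(from: 0, to: count, by: chunkSize).map {
        $0..<min($0 + chunkSize, count)
    }
}

/// Largest mask alive at any one time when frames are processed chunk by chunk
/// and each chunk's mask is released before the next one is built.
func peakMaskBytes(frames: Int, chunkSize: Int, maxPatches: Int) -> Int {
    chunkRanges(count: frames, chunkSize: chunkSize)
        .map { attentionMaskBytes(batch: $0.count, maxPatches: maxPatches) }
        .max() ?? 0
}
```

Under these assumptions a single 32-frame batch needs about 406 MB for the mask, whereas chunks of 4 keep the peak near 51 MB regardless of how many frames are sampled.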

The per-frame preprocessing chain changed from `MediaProcessing.apply` followed by `resampleBicubic` to `toSRGB().resampled().normalized()`. This avoids the `bestFitScale` `transformed(by:)` step in `MediaProcessing.apply`, which leaves the `CIImage` with a non-square extent. That extent round-trips fine through most paths but trips `asMLXArray`'s bitmap render on iOS with `verify_image_parameters: invalid image bits/pixel or bytes/row`.

Token math

A 384×384 frame yields 24×24 = 576 patches. With Gemma 4's 3×3 pooling, kernel = √(576/64) = 3 (clean), output = 8×8 = 64 soft tokens per frame. `videoSeqLength` and `visionSoftTokensPerVideoFrame` both default to 64. They must match, so if the model is published with a different budget, override both via JSON to the same value.
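As a sanity check on the arithmetic, the sketch below recomputes the pooling kernel and per-frame token count for a given frame size. The 16-pixel patch size is inferred from 384/24 and is an assumption, as is the helper name `poolingPlan`.

```swift
/// Hypothetical helper: returns the pooling kernel and resulting soft-token count
/// for a square frame, or nil when the grid does not divide cleanly.
func poolingPlan(
    frameSize: Int,
    patchSize: Int = 16,       // assumption: 384 / 24 patches per side
    targetTokens: Int = 64
) -> (kernel: Int, tokens: Int)? {
    guard patchSize > 0, targetTokens > 0, frameSize % patchSize == 0 else { return nil }
    let side = frameSize / patchSize            // patches per side, e.g. 24
    let patches = side * side                   // e.g. 576
    guard patches % targetTokens == 0 else { return nil }
    let ratio = patches / targetTokens          // e.g. 9
    let kernel = Int(Double(ratio).squareRoot().rounded())
    // The kernel must be an exact square root and must divide the patch grid.
    guard kernel > 0, kernel * kernel == ratio, side % kernel == 0 else { return nil }
    let pooledSide = side / kernel              // e.g. 8
    return (kernel, pooledSide * pooledSide)
}
```

For the default 384×384 frame this gives a kernel of 3 and 64 tokens; a 400×400 frame returns nil because 625 patches cannot be pooled evenly into 64 tokens.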

Verification

  • ✅ `swift build` clean.
  • ✅ `xcodebuild MLXChatExample` builds for macOS arm64, generic iOS Simulator, and generic iOS Device against this branch.
  • ✅ App launches on macOS without crash.
  • On-device iOS verified — iPhone 17 Pro, `mlx-community/gemma-4-e2b-it-4bit`, 8 s video clip → 42.53 tok/s decode, transcript matched the video content. (After the chunking + frame-count + CIImage chain fixes; an earlier draft of this PR with `videoMaxFrames = 32` and the `MediaProcessing.apply`-based preprocess OOM'd / hit `verify_image_parameters` on the same device.)
  • ⚠️ Companion ml-explore/mlx-swift-examples PR registers `gemma4:E2B` / `gemma4:E4B` in MLXChatExample.

Notes

  • Image path is unchanged.
  • Audio path (PR #192, "Add Gemma 4 audio tower support (ASR via Conformer encoder)") is orthogonal; this PR does not interact with `audio_tower`.
  • The fixed video frame size is a deliberate Swift-side simplification of Python's aspect-preserving 70-token-budget pipeline. If a downstream model needs a different soft-token budget, override `vision_soft_tokens_per_video_frame` and `video_seq_length` together.

Routes video frames through the existing vision_tower as a batch (same
encoder as images, just with a smaller per-frame soft token budget) and
performs text-level expansion of the `<|video|>` placeholder into a
per-frame `mm:ss <start_of_image><|video|>x N<end_of_image>` sequence
matching the HF Gemma 4 processor reference implementation.

Implementation notes:

- `Gemma4Configuration` gains `video_token_id` and
  `vision_soft_tokens_per_video_frame` (default 64).
- `Gemma4ProcessorConfiguration` gains `video_seq_length`,
  `video_frame_size`, `video_max_frames`, `video_fps`, and
  `boi_token`/`eoi_token` strings.
- `Gemma4VisionModel` accepts an `outputLength` parameter that is
  forwarded to the pooler so video frames pool down to a smaller token
  budget than images.
- `Gemma4.getInputEmbeddings` accepts `pixelValuesVideos`, runs them
  through the vision tower with `outputLength=visionSoftTokensPerVideoFrame`
  and scatters the resulting features at `video_token_id` positions in
  the prompt embedding.
- `Gemma4Processor.prepare` samples frames at `video_fps` capped at
  `video_max_frames`, resizes each to the fixed video frame size
  (default 384x384, which yields exactly 64 soft tokens per frame after
  3x3 pooling), and rewrites the decoded prompt to splice in the
  per-frame timestamped soft-token sequence before re-encoding.
- `Gemma4MessageGenerator` emits one `<|video|>` literal text marker
  per video so the chat template tokenizes them as plain text, then the
  processor expands them at the text level.
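The scatter step in `getInputEmbeddings` can be illustrated without MLX using plain nested arrays. This is only a sketch of the idea: the real code operates on tensors, and `scatterVideoFeatures` is a hypothetical name.

```swift
/// Hypothetical helper: overwrites the embedding row at every position whose
/// token id equals `videoTokenId` with the next video feature row, in order.
/// Returns nil when the placeholder count and the feature count disagree.
func scatterVideoFeatures(
    into embeddings: [[Float]],
    tokenIds: [Int],
    videoTokenId: Int,
    features: [[Float]]
) -> [[Float]]? {
    guard embeddings.count == tokenIds.count else { return nil }
    let positions = tokenIds.indices.filter { tokenIds[$0] == videoTokenId }
    // One feature row per placeholder token, otherwise the prompt and the
    // vision tower disagree about the per-frame token budget.
    guard positions.count == features.count else { return nil }
    var result = embeddings
    for (featureIndex, position) in positions.enumerated() {
        result[position] = features[featureIndex]
    }
    return result
}
```

The count check is why the processor-side and model-side per-frame token budgets have to agree: a mismatch leaves either unused features or unfilled placeholders.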

Faithful to the Python reference in
mlx-vlm/mlx_vlm/models/gemma4/processing_gemma4.py and gemma4.py;
the only Swift-side simplification is using a fixed square video frame
size (384x384) instead of an aspect-preserving resize. This keeps the
existing pooler kernel grid clean (24/3 = 8) so per-frame token counts
match the prompt placeholders exactly.

Two changes addressing crashes when the new video tower runs on
iPhone-class memory budgets:

1. Chunk the vision tower batch dim. The vision tower allocates an
   attention mask of shape (batch, 1, max_patches, max_patches); for
   max_patches=2520 and batch=32 the mask alone is ~400 MB, easily
   enough to OOM `gemma4-E2B` on iOS. `getInputEmbeddings` now walks
   `pixel_values_videos` in `videoFrameChunkSize` slices (default 4),
   evaluates each slice's projected features, then concatenates. Peak
   memory stays roughly constant in the number of frames sampled.

2. Lower default `videoMaxFrames` 32 → 16 so even pathological cases
   (chunk size override, very long videos) stay below the iOS budget.

Also replace the per-frame preprocessing chain with the proven
SmolVLM2-style chain (`CIImage.toSRGB().resampled(_:method:).normalized(_:_:)`).
The previous chain used `MediaProcessing.apply` first, whose
`bestFitScale` `transformed(by:)` leaves the CIImage with a non-square
extent that round-trips fine through most paths but trips
`asMLXArray`'s bitmap render on iOS with
`verify_image_parameters: invalid image bits/pixel or bytes/row`.

Both changes default-on; advanced users can override
`video_frame_chunk_size` and `video_max_frames` in the config.

Covers the previously-uncovered public surface from ml-explore#256:

- Gemma4ProcessorConfiguration decodes the new video_token_id /
  video_seq_length / video_frame_size / video_max_frames / video_fps
  fields, with iOS-safe defaults (videoSeqLength=64, videoMaxFrames=16,
  videoFps=2.0) applied when an older preprocessor_config.json omits
  them.
- Gemma4MessageGenerator prefixes one '<|video|>' placeholder per video
  to the text block, leaving the message text alone when no videos are
  attached.

4 new tests, all passing.
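The default-when-missing decoding that these tests exercise follows a common `Decodable` pattern, sketched below. The struct is a hypothetical stand-in, not `Gemma4ProcessorConfiguration` itself, and treating `video_frame_size` as a single integer with a default of 384 is an assumption.

```swift
import Foundation

/// Hypothetical stand-in for the video fields of the processor configuration.
struct VideoProcessorSettings: Decodable {
    let videoSeqLength: Int
    let videoFrameSize: Int
    let videoMaxFrames: Int
    let videoFps: Double

    enum CodingKeys: String, CodingKey {
        case videoSeqLength = "video_seq_length"
        case videoFrameSize = "video_frame_size"
        case videoMaxFrames = "video_max_frames"
        case videoFps = "video_fps"
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        // Fall back to the defaults named in the PR when a key is absent.
        videoSeqLength = try container.decodeIfPresent(Int.self, forKey: .videoSeqLength) ?? 64
        videoFrameSize = try container.decodeIfPresent(Int.self, forKey: .videoFrameSize) ?? 384
        videoMaxFrames = try container.decodeIfPresent(Int.self, forKey: .videoMaxFrames) ?? 16
        videoFps = try container.decodeIfPresent(Double.self, forKey: .videoFps) ?? 2.0
    }
}
```

Decoding an empty JSON object with this type yields the defaults, while a config that sets only `video_max_frames` overrides that one field and keeps the rest.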