Add Gemma 4 video tower support #256
Open
john-rocky wants to merge 3 commits into
Routes video frames through the existing `vision_tower` as a batch (same encoder as images, just with a smaller per-frame soft token budget) and performs text-level expansion of the `<|video|>` placeholder into a per-frame `mm:ss <start_of_image><|video|>x N<end_of_image>` sequence matching the HF Gemma 4 processor reference implementation.

Implementation notes:
- `Gemma4Configuration` gains `video_token_id` and `vision_soft_tokens_per_video_frame` (default 64).
- `Gemma4ProcessorConfiguration` gains `video_seq_length`, `video_frame_size`, `video_max_frames`, `video_fps`, and `boi_token`/`eoi_token` strings.
- `Gemma4VisionModel` accepts an `outputLength` parameter that is forwarded to the pooler so video frames pool down to a smaller token budget than images.
- `Gemma4.getInputEmbeddings` accepts `pixelValuesVideos`, runs them through the vision tower with `outputLength=visionSoftTokensPerVideoFrame`, and scatters the resulting features at `video_token_id` positions in the prompt embedding.
- `Gemma4Processor.prepare` samples frames at `video_fps` capped at `video_max_frames`, resizes each to the fixed video frame size (default 384x384, which yields exactly 64 soft tokens per frame after 3x3 pooling), and rewrites the decoded prompt to splice in the per-frame timestamped soft-token sequence before re-encoding.
- `Gemma4MessageGenerator` emits one `<|video|>` literal text marker per video so the chat template tokenizes them as plain text; the processor then expands them at the text level.

Faithful to the Python reference in mlx-vlm/mlx_vlm/models/gemma4/processing_gemma4.py and gemma4.py; the only Swift-side simplification is using a fixed square video frame size (384x384) instead of an aspect-preserving resize. This keeps the existing pooler kernel grid clean (24/3 = 8) so per-frame token counts match the prompt placeholders exactly.
Two changes address crashes when the new video tower runs on iPhone-class memory budgets:

1. Chunk the vision tower batch dim. The vision tower allocates an attention mask of shape (batch, 1, max_patches, max_patches); for max_patches=2520 and batch=32 the mask alone is ~400 MB, easily enough to OOM `gemma4-E2B` on iOS. `getInputEmbeddings` now walks `pixel_values_videos` in `videoFrameChunkSize` slices (default 4), evaluates each slice's projected features, then concatenates. Peak memory stays roughly constant in the number of frames sampled.
2. Lower the default `videoMaxFrames` from 32 to 16 so even pathological cases (chunk size override, very long videos) stay below the iOS budget.

Also replace the per-frame preprocessing chain with the proven SmolVLM2-style chain (`CIImage.toSRGB().resampled(_:method:).normalized(_:_:)`). The previous chain used `MediaProcessing.apply` first, whose `bestFitScale` `transformed(by:)` leaves the CIImage with a non-square extent that round-trips fine through most paths but trips `asMLXArray`'s bitmap render on iOS with `verify_image_parameters: invalid image bits/pixel or bytes/row`.

Both changes are on by default; advanced users can override `video_frame_chunk_size` and `video_max_frames` in the config.
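The chunked walk over the frame batch can be illustrated with a minimal sketch. It is not the PR's code: `chunkRanges` and `encodeInChunks` are hypothetical names, and plain integers stand in for frames and features so the slicing logic can be read in isolation from the vision tower.

```swift
/// Hypothetical helper: split `frameCount` frames into half-open index
/// ranges of at most `chunkSize` frames each.
func chunkRanges(frameCount: Int, chunkSize: Int) -> [Range<Int>] {
    precondition(chunkSize > 0, "chunk size must be positive")
    return stride(from: 0, to: frameCount, by: chunkSize).map { start in
        start..<min(start + chunkSize, frameCount)
    }
}

/// Stand-in for the per-chunk work. Each "frame" is an Int and the
/// "features" are the same values, so only the chunking is exercised.
func encodeInChunks(frames: [Int], chunkSize: Int) -> [Int] {
    var features: [Int] = []
    for range in chunkRanges(frameCount: frames.count, chunkSize: chunkSize) {
        let slice = Array(frames[range])
        // Assumption: the real code would run this slice through the vision
        // tower and evaluate it here, before moving on to the next slice.
        features.append(contentsOf: slice)
    }
    return features
}
```

With 10 frames and a chunk size of 4, the ranges are `0..<4`, `4..<8`, and `8..<10`; the last chunk is simply shorter, and the concatenated output keeps the original frame order.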
Covers the previously-uncovered public surface from ml-explore#256:
- Gemma4ProcessorConfiguration decodes the new video_token_id / video_seq_length / video_frame_size / video_max_frames / video_fps fields, with iOS-safe defaults (videoSeqLength=64, videoMaxFrames=16, videoFps=2.0) applied when an older preprocessor_config.json omits them.
- Gemma4MessageGenerator prefixes one '<|video|>' placeholder per video to the text block, leaving the message text alone when no videos are attached.

4 new tests, all passing.
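One common way to get this "default when the key is missing" behaviour in Swift is `decodeIfPresent` with a fallback. The sketch below assumes that approach; `VideoProcessorDefaultsSketch` and `decodeSketch` are hypothetical names covering only three of the fields, not the actual `Gemma4ProcessorConfiguration`.

```swift
import Foundation

/// Hypothetical, trimmed-down stand-in for the processor configuration.
struct VideoProcessorDefaultsSketch: Decodable {
    let videoSeqLength: Int
    let videoMaxFrames: Int
    let videoFps: Double

    enum CodingKeys: String, CodingKey {
        case videoSeqLength = "video_seq_length"
        case videoMaxFrames = "video_max_frames"
        case videoFps = "video_fps"
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        // Fall back to the defaults named in the commit message when an
        // older config file omits the key.
        videoSeqLength = try container.decodeIfPresent(Int.self, forKey: .videoSeqLength) ?? 64
        videoMaxFrames = try container.decodeIfPresent(Int.self, forKey: .videoMaxFrames) ?? 16
        videoFps = try container.decodeIfPresent(Double.self, forKey: .videoFps) ?? 2.0
    }
}

/// Hypothetical convenience wrapper for decoding from a JSON string.
func decodeSketch(_ json: String) throws -> VideoProcessorDefaultsSketch {
    try JSONDecoder().decode(VideoProcessorDefaultsSketch.self, from: Data(json.utf8))
}
```

Decoding `{}` yields the three defaults, while `{"video_max_frames": 8}` overrides only the frame cap and leaves the other two at their defaults.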
Summary
Routes Gemma 4 video frames through the existing `vision_tower` as a batch (same encoder as images, just with a smaller per-frame soft-token budget) and performs text-level expansion of a `<|video|>` placeholder into a per-frame timestamped sequence (`mm:ss <start_of_image><|video|>×N<end_of_image>`, joined by spaces), matching the HF Gemma 4 reference processor.
Faithful to the Python reference in the `mlx-vlm` Gemma 4 processor; the only Swift-side simplification is a fixed square video frame size (default 384×384) rather than an aspect-preserving resize, so the existing pooler kernel grid (24/3 = 8) divides cleanly and per-frame token counts match the prompt placeholders exactly.
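The text-level expansion described in the summary can be sketched as below. This is an illustration, not the processor's code: `mmss` and `expandVideoPlaceholder` are hypothetical names, the zero-padded and truncated `mm:ss` formatting is an assumption, and only the first placeholder is expanded (a prompt with several videos would need one pass per video).

```swift
import Foundation

/// Assumed formatting: whole seconds, zero-padded, as mm:ss.
func mmss(_ seconds: Double) -> String {
    let total = Int(seconds)
    return String(format: "%02d:%02d", total / 60, total % 60)
}

/// Hypothetical helper: replace the first `<|video|>` in `prompt` with one
/// timestamped, delimited run of `tokensPerFrame` video tokens per frame.
func expandVideoPlaceholder(
    in prompt: String,
    timestamps: [Double],
    tokensPerFrame: Int,
    boi: String = "<start_of_image>",
    eoi: String = "<end_of_image>",
    videoToken: String = "<|video|>"
) -> String {
    let perFrame = timestamps.map { time -> String in
        let softTokens = String(repeating: videoToken, count: tokensPerFrame)
        return "\(mmss(time)) \(boi)\(softTokens)\(eoi)"
    }
    // Per-frame pieces are joined by single spaces, as in the summary.
    let expanded = perFrame.joined(separator: " ")

    guard let range = prompt.range(of: videoToken) else { return prompt }
    var result = prompt
    result.replaceSubrange(range, with: expanded)
    return result
}
```

For two frames at 0 s and 65 s with two tokens per frame, the prompt `<|video|> Describe.` becomes `00:00 <start_of_image><|video|><|video|><end_of_image> 01:05 <start_of_image><|video|><|video|><end_of_image> Describe.`. The number of `<|video|>` tokens in the result is what has to equal the number of feature rows scattered into the embedding.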
What's added
Why chunking + lower default frame count
The vision tower allocates an attention mask of shape `(batch, 1, max_patches, max_patches)`. With `max_patches = 2520` that is 2520² ≈ 6.35M elements per frame, or ~13 MB at 2 bytes per element, so a 32-frame batch peaks at ~400 MB before any layer activations. That is enough to OOM `gemma4-E2B` on iPhone alongside the language model. The chunked dispatch keeps peak memory roughly constant in frame count, and dropping the default to 16 frames adds headroom without sacrificing useful temporal coverage at `fps = 2`.
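The mask arithmetic is easy to reproduce. The sketch below assumes 2 bytes per mask element (for example a half-precision mask), which is what the ~13 MB per-frame figure implies; `attentionMaskBytes` is a hypothetical name used only for this calculation.

```swift
/// Bytes needed for a mask of shape (batch, 1, maxPatches, maxPatches).
/// Assumption: 2 bytes per element unless stated otherwise.
func attentionMaskBytes(batch: Int, maxPatches: Int, bytesPerElement: Int = 2) -> Int {
    batch * 1 * maxPatches * maxPatches * bytesPerElement
}

let perFrameBytes = attentionMaskBytes(batch: 1, maxPatches: 2520)    // 12_700_800
let fullBatchBytes = attentionMaskBytes(batch: 32, maxPatches: 2520)  // 406_425_600
let chunkBytes = attentionMaskBytes(batch: 4, maxPatches: 2520)       // 50_803_200
```

Under that assumption, a chunk of 4 frames needs about 51 MB of mask at a time instead of about 406 MB for all 32 frames at once, and the chunk figure does not grow with the number of frames sampled.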
The `CIImage` chain change (`toSRGB().resampled().normalized()` instead of `MediaProcessing.apply` followed by `resampleBicubic`) avoids `MediaProcessing.apply`'s `bestFitScale` `transformed(by:)` step, which leaves the CIImage with a non-square extent that round-trips fine through most paths but trips `asMLXArray`'s bitmap render on iOS with `verify_image_parameters: invalid image bits/pixel or bytes/row`.
Token math
A 384×384 frame yields 24×24 = 576 patches. With Gemma 4's 3×3 pooling, kernel = √(576/64) = 3 (clean), output = 8×8 = 64 soft tokens per frame. `videoSeqLength` and `visionSoftTokensPerVideoFrame` both default to 64. If the model is published with a different budget, override both via JSON, because the two values must match.
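The same arithmetic as a small check, useful when trying a different frame size. The patch size of 16 is an assumption inferred from 384 / 24, and `softTokensPerFrame` is a hypothetical helper, not part of the PR.

```swift
/// Soft tokens produced by a square frame, or nil when the frame does not
/// divide cleanly into patches or the patch grid into pooling windows.
func softTokensPerFrame(frameSize: Int, patchSize: Int, poolKernel: Int) -> Int? {
    guard patchSize > 0, poolKernel > 0, frameSize % patchSize == 0 else { return nil }
    let patchesPerSide = frameSize / patchSize          // 384 / 16 = 24
    guard patchesPerSide % poolKernel == 0 else { return nil }
    let pooledPerSide = patchesPerSide / poolKernel     // 24 / 3 = 8
    return pooledPerSide * pooledPerSide                // 8 * 8 = 64
}
```

`softTokensPerFrame(frameSize: 384, patchSize: 16, poolKernel: 3)` returns 64, matching the default budget, while a 400×400 frame gives a 25×25 patch grid that 3×3 pooling cannot tile, so the helper returns `nil`.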
Verification
Notes