Embedding gemma #7
Conversation
Cherry-Pick of 86bb1265168363cc5096b8df5f82075a5702ef2e Co-authored-by: Tom Nickson <tnickson@apple.com>
- Add useBidirectionalAttention config parameter
- Apply sliding window size adjustment for bidirectional mode
- Implement createBidirectionalSlidingWindowMask function (see the sketch below)
- Update mask creation logic to support both causal and bidirectional attention
- Based on patches 40694 and 40700 for EmbeddingGemma support

Cherry-Picked Commit: 46be017e9f4b076f2d0842cf78175ac42d894b0a Co-authored-by: Tom Nickson <tnickson@apple.com>
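For context, a minimal sketch of what a bidirectional sliding-window mask along these lines could look like. This is an illustration assuming an additive-mask convention (0 where attention is allowed, -inf where it is blocked), not the actual createBidirectionalSlidingWindowMask implementation; all names besides that concept are made up.

```swift
import MLX

// Hypothetical sketch: token i may attend to token j iff |i - j| < windowSize,
// in both directions (no causal constraint, since an embedding model
// sees the whole sequence at once).
func bidirectionalSlidingWindowMask(sequenceLength: Int, windowSize: Int) -> MLXArray {
    let positions = MLXArray(0 ..< sequenceLength)
    // Pairwise |i - j| distances via broadcasting: [L, 1] - [1, L] -> [L, L]
    let distance = abs(
        expandedDimensions(positions, axis: 1) - expandedDimensions(positions, axis: 0))
    // 0 where attention is allowed, -inf where it is masked out, so the
    // result can be added to the attention scores before the softmax.
    return which(
        distance .< MLXArray(windowSize), MLXArray(Float(0)), MLXArray(-Float.infinity))
}
```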
Cherry-Picked Commit: 8dc179ccc21b26fb0856016ec9f2b7d5792979e0 Co-authored-by: Tom Nickson <tnickson@apple.com>
Commit: 733e142542cfaf85ca0304d37f908b176c54edfc Co-authored-by: Tom Nickson <tnickson@apple.com>
Commit: 96ee882cd7c6fd3573b034686d3f3c5afe1ee04a Co-authored-by: Tom Nickson <tnickson@apple.com>
```swift
// Copyright © 2024 Apple Inc.

import MLX
import MLXLMCommon
```
Is this just to pick up the quantization? I think the Embedders should not require MLXLMCommon / MLXLLM if possible. A copy of Quantization is OK. If we end up with a lot of duplication, then I think MLXLMCommon might make sense.
| "gemma3_text": { | ||
| url in | ||
| let configuration = try JSONDecoder().decode( | ||
| Gemma3TextConfiguration.self, from: Data(contentsOf: url)) |
I think it makes sense to copy this config type into Embedders rather than add new linkage. Even sharing config types between models in the same library is rarely done.
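To illustrate the suggestion, a hypothetical sketch of a config copy local to Embedders. The field names follow common gemma3_text JSON keys and are assumptions for illustration, not the actual Gemma3TextConfiguration definition.

```swift
import Foundation

// Hypothetical local copy of the text config, per the review suggestion,
// so Embedders needs no linkage to MLXLLM / MLXLMCommon.
// Field names are assumed, not copied from the real type.
struct EmbeddingGemmaTextConfiguration: Codable {
    let hiddenSize: Int
    let numHiddenLayers: Int
    let slidingWindow: Int
    let useBidirectionalAttention: Bool

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case numHiddenLayers = "num_hidden_layers"
        case slidingWindow = "sliding_window"
        case useBidirectionalAttention = "use_bidirectional_attention"
    }
}

// The registry closure above would then decode the local copy:
// let configuration = try JSONDecoder().decode(
//     EmbeddingGemmaTextConfiguration.self, from: Data(contentsOf: url))
```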
```swift
import MLXLLM
import MLXLMCommon
```
See elsewhere -- I think this should be done without adding linkage to additional libraries.
```diff
  @ModuleInfo private var model: Gemma3Model
- @ModuleInfo(key: "lm_head") var lmHead: Linear
+ @ModuleInfo(key: "lm_head") var lmHead: Module  // Can be Linear or QuantizedLinear
```
QuantizedLinear is a subtype of Linear so this should have been OK as-is -- did you see a problem here?
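For reference, a small sketch of the subtype relationship the comment describes. The initializer parameters shown are the MLXNN defaults as I understand them, so treat this as illustrative rather than authoritative.

```swift
import MLX
import MLXNN

// QuantizedLinear subclasses Linear in MLXNN, so a property typed as
// Linear can already hold either variant -- no need to widen to Module.
let plain: Linear = Linear(512, 256)
let quantized: Linear = QuantizedLinear(512, 256, groupSize: 64, bits: 4)

// Both dispatch through the same Linear-typed call:
let x = zeros([1, 512])
let y1 = plain(x)      // float matmul
let y2 = quantized(x)  // quantized matmul, same interface
```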
davidkoski
left a comment

@dmunch this looks good overall, but see my comments about not adding a new dependency on MLXLMCommon and MLXLLM. Thanks!
…ox output

When porting Qwen2.5-VL grounding inference to native Swift via mlx-swift-lm, the model consistently produced wrong bbox_2d coordinates (the "panels" it detected were hallucinated by hundreds of pixels). After weeks of debugging, I found 9 distinct bugs. With all fixes applied, Swift output matches the Python reference (mlx-vlm) at 0px delta on all 8 bbox edges.

Full writeup + diagnostics: https://dev.to/nivdvir/building-a-real-time-screen-reader-on-macos-that-actually-works-471

The 9 bugs:

1. MROPE section selection (split-select vs slice-replace). Multi-Resolution Rotary Position Embedding assigns different frequency bands to temporal (T), height (H), and width (W). The Swift impl split the frequency tensor via modulo indexing (i % 3), interleaving frequencies. Python starts with temporal freqs and overwrites the H/W slices in place: [T_0-15, H_16-39, W_40-63]. The wrong layout destroys attention patterns. (See the first sketch after this list.)

2. invFreq registered as a Module weight. invFreq is a computed constant, not a learned weight. Declaring it as a property on a Module subclass exposed it to MLX's weight loader, which either threw keyNotFound or silently overwrote it with garbage. Fix: wrap it in a non-Module class (InvFreqBox) to hide it from reflection.

3. rope_deltas unused during autoregressive generation. After the prefill pass, cached position IDs were cleared but rope_deltas were never applied to subsequent tokens. Correct computation: positionIds = cache_offset + rope_deltas + arange(seqLen). Without the deltas, position embeddings drifted with each generated token. (See the second sketch after this list.)

4. MROPE state not reset between successive images. Cached position IDs and rope deltas from one inference persisted into the next, so processing a new image meant position embeddings started from the previous image's offset. Progressively worse results on the 2nd and 3rd images.

5. Image resize using 1800px max instead of 1280px. The Swift code resized to max 1800px (2688 visual tokens); the Python reference uses 1280px max (1305 visual tokens). The model was trained on 1280px, so 1800px pushed visual token positions outside the training distribution.

6. Chat template ordering (text vs image token placement). The Swift message generator placed text before the image token in the content array. Python puts the image first: <|vision_start|><|image_pad|><|vision_end|>PROMPT. Ordering matters: text tokens attending to positions where image features have not yet been injected produces wrong attention patterns.

7. Vision attention mask ignored — THE ROOT CAUSE. The vision encoder's self-attention uses a mask for windowed attention (each patch only attends to patches within its window). Swift passed mask: .none to scaledDotProductAttention instead of mask: .array(floatMask). Result: every patch attended globally to every other patch, destroying the spatial locality the model relies on for precise coordinate prediction. This single bug was most responsible for the bbox inaccuracy. (See the third sketch after this list.)

Tests against the Python reference (mlx-vlm 0.1.31, Qwen2.5-VL-7B-Instruct-4bit):

- Before fixes: bounding boxes off by 200-800px, inconsistent across runs
- After fixes: 0px delta on all 8 bbox edges (x1, y1, x2, y2 for 2 panels)

Filed in context of upstream issue ml-explore#221 (Upstreaming improvements from fast-moving forks). Happy to split this into smaller PRs if maintainers prefer. Related bugs ml-explore#6 (prompt format) and ml-explore#7 (maxTokens) live in consumer code, not mlx-swift-lm, so they're out of scope for this PR.
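First, a minimal sketch of the slice-replace layout from bug 1, reduced to a 1-D vector of 64 frequency bands. mropeFrequencies and its parameters are illustrative names, not the mlx-swift-lm API; real tensors carry extra batch/position dimensions.

```swift
import MLX

// Build [T_0-15, H_16-39, W_40-63] by taking whole bands from each source
// (slice-replace), rather than interleaving bands via modulo indexing (i % 3).
func mropeFrequencies(temporal: MLXArray, height: MLXArray, width: MLXArray) -> MLXArray {
    concatenated(
        [temporal[0 ..< 16], height[16 ..< 40], width[40 ..< 64]],
        axis: 0)
}
```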
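Second, a sketch of the bug-3 position computation. decodePositionIds, cacheOffset, and ropeDeltas are illustrative names for state that actually lives in the model's cache.

```swift
import MLX

// positionIds = cache_offset + rope_deltas + arange(seqLen):
// every decoded token must carry the prefill's rope delta forward,
// otherwise position embeddings drift with each generated token.
func decodePositionIds(cacheOffset: Int, ropeDeltas: MLXArray, seqLen: Int) -> MLXArray {
    MLXArray(cacheOffset) + ropeDeltas + MLXArray(0 ..< seqLen)
}
```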
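Third, a sketch of the bug-7 fix, assuming the mask-mode overload of scaledDotProductAttention that the report quotes; the windowed float mask (0 inside a patch's window, -inf outside) is computed elsewhere in the vision encoder and just needs to reach the attention call.

```swift
import MLX
import MLXFast

// Buggy version attended globally:  mask: .none
// Fixed version restores windowed attention by passing the mask through.
func windowedVisionAttention(
    queries: MLXArray, keys: MLXArray, values: MLXArray,
    scale: Float, windowMask: MLXArray
) -> MLXArray {
    MLXFast.scaledDotProductAttention(
        queries: queries, keys: keys, values: values,
        scale: scale, mask: .array(windowMask))
}
```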
Proposed changes
Adding support for EmbeddingGemma to Embedders. Basically cherry-picked all commits from ml-explore/mlx-swift-examples#398 to the new repository, made sure everything compiles, and ran swift-format.
Pretty new to MLX; I thought this would be a good learning opportunity and a chance to get something in. Let me know what modifications etc. you'd still like in order to get this merged.
Checklist
Put an x in the boxes that apply.

- [ ] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes