Skip to content

Embedding gemma#7

Open
dmunch wants to merge 10 commits intoml-explore:mainfrom
dmunch:embedding_gemma
Open

Embedding gemma#7
dmunch wants to merge 10 commits intoml-explore:mainfrom
dmunch:embedding_gemma

Conversation

@dmunch
Copy link
Copy Markdown

@dmunch dmunch commented Nov 12, 2025

Proposed changes

Adding support for EmbeddingGemma to Embedders.

Basically cherry-picked all commits from ml-explore/mlx-swift-examples#398 from to the new repository, made sure everything compiles and ran swift-format.

Pretty new to MLX, and thought that would be a good learning opportunity and try to get something in. Let me know what kind of modifications etc. you'd still like to have to get this merged.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

dmunch and others added 10 commits November 12, 2025 17:20
Cherry-Pick of 86bb1265168363cc5096b8df5f82075a5702ef2e
Co-authored-by: Tom Nickson <tnickson@apple.com>
Cherry-Pick of d44e2c3d6d5365655aa0e179432cf3548ecd17d4
Co-authored-by: Tom Nickson <tnickson@apple.com>
- Add useBidirectionalAttention config parameter
- Apply sliding window size adjustment for bidirectional mode
- Implement createBidirectionalSlidingWindowMask function
- Update mask creation logic to support both causal and bidirectional attention
- Based on patches 40694 and 40700 for EmbeddingGemma support

Cherry-Picked
Commit: 46be017e9f4b076f2d0842cf78175ac42d894b0a
Co-authored-by: Tom Nickson <tnickson@apple.com>
Cherry-Picked
Commit: 8dc179ccc21b26fb0856016ec9f2b7d5792979e0
Co-authored-by: Tom Nickson <tnickson@apple.com>
Commit: 733e142542cfaf85ca0304d37f908b176c54edfc
Co-authored-by: Tom Nickson <tnickson@apple.com>
Commit: 96ee882cd7c6fd3573b034686d3f3c5afe1ee04a
Co-authored-by: Tom Nickson <tnickson@apple.com>
// Copyright © 2024 Apple Inc.

import MLX
import MLXLMCommon
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just to pick up the quantization? Think the Embedders should not require MLXLMCommon / MLXLLM if possible. A copy of Quantization is OK. If we end up with a lot of duplication then I think MLXLMCommon might make sense.

"gemma3_text": {
url in
let configuration = try JSONDecoder().decode(
Gemma3TextConfiguration.self, from: Data(contentsOf: url))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to copy this config type into Embedders rather than add new linkage. Even sharing config types between models in the same library is rarely done.

Comment on lines +2 to +3
import MLXLLM
import MLXLMCommon
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See elsewhere -- I think this should be done without adding linkage to additional libraries.


@ModuleInfo private var model: Gemma3Model
@ModuleInfo(key: "lm_head") var lmHead: Linear
@ModuleInfo(key: "lm_head") var lmHead: Module // Can be Linear or QuantizedLinear
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QuantizedLinear is a subtype of Linear so this should have been OK as-is -- did you see a problem here?

@davidkoski
Copy link
Copy Markdown
Collaborator

@dmunch this looks good overall, but see my comments about not adding a new dependency on MLXLMCommon and MLXLLM.

Thanks!

Copy link
Copy Markdown
Collaborator

@davidkoski davidkoski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dmunch this looks good overall, but see my comments about not adding a new dependency on MLXLMCommon and MLXLLM.

@davidkoski davidkoski added the changes requested PR Feedback - changes requested label Feb 16, 2026
JaeminKim-amoz

This comment was marked as spam.

NivDvir added a commit to NivDvir/mlx-swift-lm that referenced this pull request Apr 19, 2026
…ox output

When porting Qwen2.5-VL grounding inference to native Swift via mlx-swift-lm,
the model consistently produced wrong bbox_2d coordinates (the "panels" it
detected were hallucinated by hundreds of pixels). After weeks of debugging,
found 9 distinct bugs. With all fixes applied, Swift output matches the Python
reference (mlx-vlm) at 0px delta on all 8 bbox edges.

Full writeup + diagnostics: https://dev.to/nivdvir/building-a-real-time-screen-reader-on-macos-that-actually-works-471

The 9 bugs:

1. MROPE section selection (split-select vs slice-replace)
   Multi-Resolution Rotary Position Embedding assigns different frequency
   bands to temporal (T), height (H), and width (W). The Swift impl split the
   frequency tensor via modulo indexing (i % 3), interleaving frequencies.
   Python starts with temporal freqs and overwrites H/W slices in place:
   [T_0-15, H_16-39, W_40-63]. Wrong layout destroys attention patterns.

2. invFreq registered as a Module weight
   invFreq is a computed constant, not a learned weight. Declaring it as a
   property on a Module subclass exposed it to MLX's weight loader, which
   either threw keyNotFound or silently overwrote with garbage. Fix: wrap in
   a non-Module class (InvFreqBox) to hide from reflection.

3. rope_deltas unused during autoregressive generation
   After the prefill pass, cached position IDs were cleared but rope_deltas
   were never applied to subsequent tokens. Correct computation:
   positionIds = cache_offset + rope_deltas + arange(seqLen). Without deltas,
   position embeddings drifted with each generated token.

4. MROPE state not reset between successive images
   Cached position IDs and rope deltas from one inference persisted into the
   next. Processing a new image meant position embeddings started from the
   previous image's offset. Progressively worse results on 2nd, 3rd images.

5. Image resize using 1800px max instead of 1280px
   Swift code resized to max 1800px (2688 visual tokens). Python reference
   uses 1280px max (1305 visual tokens). The model was trained on 1280px.
   1800px pushed visual token positions outside the training distribution.

6. Chat template ordering (text vs image token placement)
   Swift message generator placed text before the image token in the content
   array. Python puts image first: <|vision_start|><|image_pad|><|vision_end|>PROMPT.
   Ordering matters: text tokens attending to positions where image features
   have not yet been injected produces wrong attention patterns.

7. Vision attention mask ignored — THE ROOT CAUSE
   The vision encoder's self-attention uses a mask for windowed attention
   (each patch only attends to patches within its window). Swift passed
   mask: .none to scaledDotProductAttention instead of mask: .array(floatMask).
   Result: every patch attended globally to every other patch, destroying
   spatial locality the model relies on for precise coordinate prediction.
   This single bug was most responsible for bbox inaccuracy.

Tests against Python reference (mlx-vlm 0.1.31, Qwen2.5-VL-7B-Instruct-4bit):
- Before fixes: bounding boxes off by 200-800px, inconsistent across runs
- After fixes: 0px delta on all 8 bbox edges (x1,y1,x2,y2 for 2 panels)

Filed in context of upstream issue ml-explore#221 (Upstreaming improvements from
fast-moving forks). Happy to split into smaller PRs if maintainers prefer.

Related bugs ml-explore#6 (prompt format) and ml-explore#7 (maxTokens) live in consumer code,
not mlx-swift-lm, so they're out of scope for this PR.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

changes requested PR Feedback - changes requested

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants