Skip to content

Commit 45a501c

Browse files
Blaizzyclaude
andcommitted
qwen3_vl: match HF reference by fixing two upstream mlx-vlm bugs
On the 6-query × 6-image retrieval benchmark, the mlx-embeddings output had max|cosine diff| = 0.087 vs the HF transformers reference and only 83% top-1 agreement. Three fixes close the gap to max 0.006 diff and 100% top-1/top-3 agreement: 1. Forward the embedder's MIN_PIXELS/MAX_PIXELS (4096..1,843,200) onto the inner image_processor. The Qwen3-VL preprocessor_config.json lists the full-context size bounds (16 MP), so without this override the image_processor resized to a different grid than the HF reference and the comparison ran on different visual tokens. 2. Work around an mlx-vlm bug in Qwen3-VL get_input_embeddings: the upstream assigns `mx.eval(deepstack_image_embeds)` to `deepstack_visual_embeds`, but mx.eval returns None — so multi-scale deepstack features were silently dropped at every LM layer the model was supposed to inject them into. Re-run the vision tower in our Model.get_input_embeddings when we detect this. 3. Patch mlx-vlm's `_deepstack_process` on the language-model instance: upstream indexes the full concatenated visual_embeds at each batch sample's image positions, which only works for batch_size=1. Our patched version slices visual_embeds per sample using a running offset so multi-image batches work. The two bugs are stacked: while (2) drops the deepstack features, (3) stays hidden, and as soon as (2) is worked around (or fixed upstream), (3) surfaces for multi-image batches; single-image batches are unaffected by (3). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1bd3299 commit 45a501c

2 files changed

Lines changed: 74 additions & 0 deletions

File tree

mlx_embeddings/models/qwen3_vl/model.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Dict, Optional, Tuple
33

44
import mlx.core as mx
5+
import numpy as np
56
from mlx_lm.models.base import create_causal_mask
67
from mlx_vlm.models.qwen3_vl import LanguageModel as Qwen3VLLanguageModel
78
from mlx_vlm.models.qwen3_vl import Model as Qwen3VLBackbone
@@ -11,6 +12,40 @@
1112
from ..base import BaseModelArgs, BaseModelOutput, normalize_embeddings
1213

1314

15+
def _patched_deepstack_process(
    self,
    hidden_states: mx.array,
    visual_pos_masks: mx.array,
    visual_embeds: mx.array,
) -> mx.array:
    """Add deepstack visual features onto the hidden states, sample by sample.

    Replacement for mlx-vlm's Qwen3-VL ``_deepstack_process``. The upstream
    version hands the whole concatenated ``visual_embeds`` to every sample's
    ``.at[indices].add(...)`` call, which only lines up when the batch
    contains a single image. Here each sample receives just its own slice of
    ``visual_embeds``, located with a running count of the visual positions
    consumed by the preceding samples.

    Args:
        self: Instance the function is bound to; not read by the body.
        hidden_states: Array indexed by sample along its first axis.
        visual_pos_masks: Per-sample mask, indexed by sample along its first
            axis; non-zero entries mark the positions that receive visual
            features.
        visual_embeds: Visual features for all samples, concatenated along
            the first axis in batch order.

    Returns:
        The per-sample results stacked back along axis 0. Samples with no
        masked position are passed through unchanged.
    """
    outputs = []
    consumed = 0  # rows of visual_embeds already assigned to earlier samples
    for sample in range(hidden_states.shape[0]):
        sample_hidden = hidden_states[sample]
        # np.where returns one index array per mask axis; keep the first.
        positions = np.where(visual_pos_masks[sample])[0]
        count = int(positions.shape[0])
        if count > 0:
            rows = mx.array(positions, dtype=mx.uint32)
            chunk = visual_embeds[consumed : consumed + count]
            sample_hidden = mx.array(sample_hidden).at[rows].add(chunk)
            consumed += count
        outputs.append(sample_hidden)
    return mx.stack(outputs, axis=0)
47+
48+
1449
def build_qwen3_vl_config(vlm_config: Dict[str, Any]) -> ModelConfig:
1550
base_config = dict(vlm_config)
1651
base_config["model_type"] = "qwen3_vl"
@@ -159,6 +194,14 @@ class Model(Qwen3VLBackbone):
159194
def __init__(self, config: ModelArgs):
    """Build the Qwen3-VL backbone and patch its deepstack routine.

    Args:
        config: Embedding-model arguments; ``config.vlm_config`` is turned
            into the backbone configuration via ``build_qwen3_vl_config``.
    """
    self.args = config
    backbone_config = build_qwen3_vl_config(config.vlm_config)
    super().__init__(backbone_config)
    # Upstream mlx-vlm (as of 0.4.4) indexes the full concatenated
    # visual_embeds at every batch sample's image positions in
    # _deepstack_process, which is only right for batch_size=1. Bind the
    # per-sample-slicing replacement onto this instance's inner language
    # model so it shadows the class-level method.
    decoder = self.language_model.model
    bound_patch = _patched_deepstack_process.__get__(decoder, type(decoder))
    decoder._deepstack_process = bound_patch
162205

163206
@property
164207
def visual(self):
@@ -178,6 +221,29 @@ def get_video_features(
178221
) -> mx.array:
179222
return self.vision_tower(pixel_values, video_grid_thw)[0]
180223

224+
def get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs):
    """Return the backbone's input embeddings with deepstack features restored.

    Works around an mlx-vlm bug (as of 0.4.4): Qwen3-VL's
    ``get_input_embeddings`` stores the result of
    ``mx.eval(deepstack_image_embeds)`` in ``deepstack_visual_embeds``, and
    ``mx.eval`` returns ``None``, so the multi-scale deepstack features are
    silently lost (about 0.1 cosine on the final image embedding). When
    pixels were supplied, the features came back ``None``, and the vision
    config declares deepstack indexes, the vision tower is run a second time
    only to recover the deepstack list.

    Args:
        input_ids: Forwarded unchanged to the parent implementation.
        pixel_values: Forwarded to the parent; also fed to the vision tower
            again when the deepstack features need to be recovered.
        **kwargs: Forwarded to the parent. ``image_grid_thw`` (preferred) or
            ``video_grid_thw`` is read as the grid for the second vision
            tower pass.

    Returns:
        The parent's feature object, with ``deepstack_visual_embeds`` filled
        in when the workaround applied.
    """
    features = super().get_input_embeddings(
        input_ids=input_ids, pixel_values=pixel_values, **kwargs
    )

    # Nothing to repair without visual input.
    if pixel_values is None:
        return features
    # Upstream already produced deepstack features (e.g. the bug is fixed).
    if features.deepstack_visual_embeds is not None:
        return features
    # This vision config declares no deepstack layers.
    if not getattr(self.config.vision_config, "deepstack_visual_indexes", None):
        return features

    grid_thw = kwargs.get("image_grid_thw")
    if grid_thw is None:
        grid_thw = kwargs.get("video_grid_thw")

    # Cast pixels to the dtype of the patch-embedding projection weights.
    tower_dtype = self.vision_tower.patch_embed.proj.weight.dtype
    _, deepstack = self.vision_tower(pixel_values.astype(tower_dtype), grid_thw)
    features.deepstack_visual_embeds = deepstack
    return features
246+
181247
def get_binary_logits(self, pooled: mx.array) -> mx.array:
182248
if hasattr(self.language_model, "lm_head"):
183249
token_logits = self.language_model.lm_head(pooled)

mlx_embeddings/models/qwen3_vl/processor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,14 @@ def from_pretrained(cls, model_path, **kwargs):
711711
kwargs.pop("use_fast", None)
712712

713713
processor = Qwen3VLProcessor.from_pretrained(model_path, **kwargs)
714+
# preprocessor_config.json often caps max_pixels at the full Qwen3-VL
715+
# context (e.g. 1,310,720 or 16M); the embedder-specific resize budget
716+
# (4096..1,843,200) must win so our resize matches the HF reference.
717+
processor.image_processor.min_pixels = min_pixels
718+
processor.image_processor.max_pixels = max_pixels
719+
if processor.video_processor is not None:
720+
processor.video_processor.min_pixels = min_pixels
721+
processor.video_processor.max_pixels = max_pixels
714722
return cls(
715723
processor=processor,
716724
embedding_max_length=embedding_max_length,

0 commit comments

Comments
 (0)