From 192e94f936d5a858e3d85546d60391b3eef056f6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 17:48:08 +0100 Subject: [PATCH 01/19] add token merging, filtering and update cache --- mlx_vlm/models/base.py | 128 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 126 insertions(+), 2 deletions(-) diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index ee47eb640..03d9994f8 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional import mlx.core as mx +import mlx.nn as nn from PIL import Image from transformers.image_processing_utils import BaseImageProcessor as ImageProcessor from transformers.image_processing_utils import get_size_dict @@ -97,6 +98,10 @@ def update(self, keys, values): self.keys[..., prev : self.offset, :] = keys self.values[..., prev : self.offset, :] = values + @property + def state(self): + return self.keys, self.values + class SimpleKVCache: """A simple key-value cache for transformer attention layers. @@ -148,7 +153,7 @@ def update(self, keys, values): class RotatingKVCache: - def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256): + def __init__(self, head_dim, n_kv_heads, max_size, keep=None, step=256): self.n_kv_heads = n_kv_heads if isinstance(head_dim, int): self.k_head_dim = self.v_head_dim = head_dim @@ -156,7 +161,7 @@ def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256): self.k_head_dim, self.v_head_dim = head_dim else: raise ValueError("head_dim must be an int or a tuple of two ints") - self.keep = keep + self.keep = keep if keep is not None else step // 2 self.keys = None self.values = None self.offset = 0 @@ -271,3 +276,122 @@ class LanguageModelOutput: logits: mx.array cross_attention_states: Optional[List[mx.array]] = None encoder_outputs: Optional[List[mx.array]] = None + + +class BaseModel(nn.Module): + def __init__(self): + super().__init__() + self.vision_tower = None + self.language_model = None + + def prefill(self, input_embeds, cache=None, prefill_step_size=256): + # Process input in batches for better parallelization + num_batches = ( + input_embeds.shape[1] + prefill_step_size - 1 + ) // prefill_step_size + + if num_batches > 1: + # Pre-allocate slices for better memory efficiency + slices = [ + input_embeds[:, i * prefill_step_size : (i + 1) * prefill_step_size, :] + for i in range(num_batches - 1) + ] + + # Process all full-sized batches in parallel + for slice in slices: + mask = create_attention_mask(slice, cache) + self.language_model(inputs_embeds=slice, cache=cache, mask=mask) + if cache is not None: + mx.eval([c.state for c in cache]) + mx.metal.clear_cache() + + # Return remaining slice + remaining_embeds = input_embeds[ + :, (num_batches - 1) * prefill_step_size :, : + ] + return remaining_embeds + + return input_embeds + + def get_topk_tokens(self, image_feature, attn, dominant_tokens_ratio=None): + batch_size, seq_len = image_feature.shape[:2] + + k_tokens = ( + int(image_feature.shape[1] * dominant_tokens_ratio) + if dominant_tokens_ratio is not None + else None + ) # keep 25% of the visual tokens + if k_tokens is None: + return image_feature + cls_idx = 0 # self.config.image_token_index + + attn_rec = mx.sum(attn[:, :, cls_idx + 1 :, cls_idx], axis=1) + + topk_idx = mx.argsort(attn_rec, axis=1)[:, -k_tokens:] + # use this to plot the dominant attention map + # https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/model/multimodal_encoder/clip_encoder.py#L62 + # https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/serve/gradio_web_server.py#L424 + + # Create CLS token indices array + # Shape: (B, 1) + cls_indices = mx.full((batch_size, 1), cls_idx, dtype=mx.int32) + + # Concat with CLS token index + # Add 1 to account for the offset after CLS token + dominant_idx = mx.concatenate([cls_indices, topk_idx + cls_idx + 1], axis=1) + + image_feature = mx.take(image_feature, dominant_idx, axis=1)[0] + return image_feature + + def merge_similar_visual_tokens( + self, image_feature, visual_token_ratio, merge_ratio=0.4 + ): + # Skip CLS token (first token) + tokens = image_feature[:, 1:] + batch_size, num_tokens, hidden_dim = tokens.shape + + # Calculate target number of tokens + target_tokens = max(1, int(num_tokens * visual_token_ratio)) + + while num_tokens > target_tokens: + # Calculate similarities between adjacent tokens + tokens_a = tokens[:, :-1] # all except last + tokens_b = tokens[:, 1:] # all except first + + # Calculate cosine similarity + a_norm = mx.sqrt(mx.sum(tokens_a * tokens_a, axis=-1, keepdims=True)) + b_norm = mx.sqrt(mx.sum(tokens_b * tokens_b, axis=-1, keepdims=True)) + similarities = mx.sum(tokens_a * tokens_b, axis=-1) + similarities = similarities / (a_norm.squeeze(-1) * b_norm.squeeze(-1)) + + # Sort similarities and get indices of pairs to merge + # We'll merge about 50% of remaining excess tokens in each iteration + num_to_merge = max(1, int((num_tokens - target_tokens) * merge_ratio)) + merge_indices = mx.argsort(similarities, axis=-1)[:, -num_to_merge:] + + # Create a list to track which indices to merge + to_merge = set(merge_indices[0].tolist()) + + # Merge selected pairs + new_tokens = [] + i = 0 + while i < num_tokens: + if i < num_tokens - 1 and i in to_merge: + # Merge this token with the next one + merged = (tokens[:, i : i + 1] + tokens[:, i + 1 : i + 2]) / 2 + new_tokens.append(merged) + i += 2 + elif i > 0 and (i - 1) in to_merge: + # Skip this token as it was merged in the previous step + i += 1 + else: + # Keep this token as is + new_tokens.append(tokens[:, i : i + 1]) + i += 1 + + # Update tokens + tokens = mx.concatenate(new_tokens, axis=1) + num_tokens = tokens.shape[1] + + # Reattach CLS token + return mx.concatenate([image_feature[:, :1], tokens], axis=1) From aa830d11c90dc0451ec51b50f0fa5dfd57f81dcd Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 17:58:15 +0100 Subject: [PATCH 02/19] add merge and filter tokens --- mlx_vlm/generate.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 1527c2b68..886e6cb98 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -69,10 +69,26 @@ def parse_arguments(): help="Maximum number of tokens to generate.", ) parser.add_argument( - "--temp", type=float, default=DEFAULT_TEMP, help="Temperature for sampling." + "--temperature", + type=float, + default=DEFAULT_TEMP, + help="Temperature for sampling.", ) parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.") parser.add_argument("--verbose", action="store_false", help="Detailed output.") + parser.add_argument( + "--vision-merge-ratio", + type=float, + default=1.0, + help="Ratio of vision tokens to keep during merging similar tokens (between 0.1 and 1.0).", + choices=[x / 10 for x in range(1, 11)], + ) + parser.add_argument( + "--vision-filter-ratio", + type=float, + help="Ratio of vision tokens to keep during filtering topk tokens (between 0.1 and 1.0).", + choices=[x / 10 for x in range(1, 11)], + ) return parser.parse_args() @@ -97,16 +113,20 @@ def main(): prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image)) kwargs = {} + if args.max_kv_size is not None: + kwargs["max_kv_size"] = args.max_kv_size if args.resize_shape is not None: - resize_shape = args.resize_shape - if len(resize_shape) not in [1, 2]: + if len(args.resize_shape) not in [1, 2]: raise ValueError("Resize shape must be 1 or 2 integers") kwargs["resize_shape"] = ( - (resize_shape[0], resize_shape[0]) - if len(resize_shape) == 1 - else resize_shape + (args.resize_shape[0],) * 2 + if len(args.resize_shape) == 1 + else tuple(args.resize_shape) ) + kwargs["vision_merge_ratio"] = args.vision_merge_ratio + kwargs["vision_filter_ratio"] = args.vision_filter_ratio + if args.chat: chat = [] if args.system: @@ -124,7 +144,7 @@ def main(): prompt, args.image, max_tokens=args.max_tokens, - temp=args.temp, + temperature=args.temperature, **kwargs, ): response += chunk.text @@ -139,7 +159,7 @@ def main(): processor, prompt, image=args.image, - temp=args.temp, + temperature=args.temperature, max_tokens=args.max_tokens, verbose=args.verbose, **kwargs, From b3501a2ed5533e33f4877c62062d8de9ef5df829 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 17:59:31 +0100 Subject: [PATCH 03/19] remove prefill --- mlx_vlm/models/base.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index 03d9994f8..185f389c3 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -284,35 +284,6 @@ def __init__(self): self.vision_tower = None self.language_model = None - def prefill(self, input_embeds, cache=None, prefill_step_size=256): - # Process input in batches for better parallelization - num_batches = ( - input_embeds.shape[1] + prefill_step_size - 1 - ) // prefill_step_size - - if num_batches > 1: - # Pre-allocate slices for better memory efficiency - slices = [ - input_embeds[:, i * prefill_step_size : (i + 1) * prefill_step_size, :] - for i in range(num_batches - 1) - ] - - # Process all full-sized batches in parallel - for slice in slices: - mask = create_attention_mask(slice, cache) - self.language_model(inputs_embeds=slice, cache=cache, mask=mask) - if cache is not None: - mx.eval([c.state for c in cache]) - mx.metal.clear_cache() - - # Return remaining slice - remaining_embeds = input_embeds[ - :, (num_batches - 1) * prefill_step_size :, : - ] - return remaining_embeds - - return input_embeds - def get_topk_tokens(self, image_feature, attn, dominant_tokens_ratio=None): batch_size, seq_len = image_feature.shape[:2] From 85e3ae8ffd7ada44ae36d1a5937c86e1c07fb601 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 18:04:08 +0100 Subject: [PATCH 04/19] add merging and filtering --- mlx_vlm/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index e1d0c37d8..d9dbf8a8b 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -1096,6 +1096,9 @@ def generate( print("Image:", image, "\n") print("Prompt:", prompt) + vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) + vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) + text = "" last_response = None for response in stream_generate(model, processor, prompt, image, **kwargs): @@ -1109,8 +1112,12 @@ def generate( if len(text) == 0: print("No text generated for this prompt") return + + total_tokens = ( + last_response.prompt_tokens * vision_merge_ratio + ) * vision_filter_ratio print( - f"Prompt: {last_response.prompt_tokens} tokens, " + f"Prompt: {int(total_tokens)} tokens, " f"{last_response.prompt_tps:.3f} tokens-per-sec" ) print( From d2c940aba542fda84e4f4e8451fba07433d8739c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 18:04:19 +0100 Subject: [PATCH 05/19] add return attn --- mlx_vlm/models/llava_next/vision.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mlx_vlm/models/llava_next/vision.py b/mlx_vlm/models/llava_next/vision.py index 5a5ec42e0..16e74622a 100644 --- a/mlx_vlm/models/llava_next/vision.py +++ b/mlx_vlm/models/llava_next/vision.py @@ -96,12 +96,12 @@ def __call__(self, queries, keys, values, mask=None): keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - output = mx.fast.scaled_dot_product_attention( + attn = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + output = attn.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output) + return self.out_proj(output), attn class MLP(nn.Module): @@ -130,11 +130,11 @@ def __init__(self, config: VisionConfig): def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: y = self.layer_norm1(x) - y = self.self_attn(y, y, y, mask) + y, attn = self.self_attn(y, y, y, mask) x = x + y y = self.layer_norm2(x) y = self.mlp(y) - return x + y + return x + y, attn class Encoder(nn.Module): @@ -197,14 +197,16 @@ def __call__( x = self.pre_layrnorm(x) encoder_states = (x,) if output_hidden_states else None + attns = tuple() for l in self.encoder.layers: - x = l(x, mask=None) + x, attn = l(x, mask=None) if output_hidden_states: encoder_states = encoder_states + (x,) + attns = attns + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) - return pooler_output, x, encoder_states + return pooler_output, x, encoder_states, attns class VisionModel(nn.Module): From db866fb0837624783c6160803dfecbfa93d6e147 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 18:13:28 +0100 Subject: [PATCH 06/19] remove max_kv_size --- mlx_vlm/generate.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 886e6cb98..33fa7f44b 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -113,8 +113,6 @@ def main(): prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image)) kwargs = {} - if args.max_kv_size is not None: - kwargs["max_kv_size"] = args.max_kv_size if args.resize_shape is not None: if len(args.resize_shape) not in [1, 2]: raise ValueError("Resize shape must be 1 or 2 integers") From 33c79a1a05f2d4eebbec4fb381e4b0c72420b80c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 18:24:47 +0100 Subject: [PATCH 07/19] add attn and toke filter and merge --- mlx_vlm/models/base.py | 14 +++++++------- mlx_vlm/models/llava/vision.py | 16 ++++++++++------ mlx_vlm/models/llava_bunny/vision.py | 16 ++++++++++------ mlx_vlm/models/llava_next/vision.py | 6 ++++-- mlx_vlm/models/qwen2_vl/qwen2_vl.py | 21 +++++++++++++++++---- mlx_vlm/models/qwen2_vl/vision.py | 19 ++++++++++++------- 6 files changed, 60 insertions(+), 32 deletions(-) diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index 185f389c3..6e25605e1 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -284,12 +284,12 @@ def __init__(self): self.vision_tower = None self.language_model = None - def get_topk_tokens(self, image_feature, attn, dominant_tokens_ratio=None): + def filter_topk_vision_tokens(self, image_feature, attn, vision_filter_ratio=None): batch_size, seq_len = image_feature.shape[:2] k_tokens = ( - int(image_feature.shape[1] * dominant_tokens_ratio) - if dominant_tokens_ratio is not None + int(image_feature.shape[1] * vision_filter_ratio) + if vision_filter_ratio is not None else None ) # keep 25% of the visual tokens if k_tokens is None: @@ -314,15 +314,15 @@ def get_topk_tokens(self, image_feature, attn, dominant_tokens_ratio=None): image_feature = mx.take(image_feature, dominant_idx, axis=1)[0] return image_feature - def merge_similar_visual_tokens( - self, image_feature, visual_token_ratio, merge_ratio=0.4 + def merge_similar_vision_tokens( + self, image_feature, vision_merge_ratio, merge_rate=0.4 ): # Skip CLS token (first token) tokens = image_feature[:, 1:] batch_size, num_tokens, hidden_dim = tokens.shape # Calculate target number of tokens - target_tokens = max(1, int(num_tokens * visual_token_ratio)) + target_tokens = max(1, int(num_tokens * vision_merge_ratio)) while num_tokens > target_tokens: # Calculate similarities between adjacent tokens @@ -337,7 +337,7 @@ def merge_similar_visual_tokens( # Sort similarities and get indices of pairs to merge # We'll merge about 50% of remaining excess tokens in each iteration - num_to_merge = max(1, int((num_tokens - target_tokens) * merge_ratio)) + num_to_merge = max(1, int((num_tokens - target_tokens) * merge_rate)) merge_indices = mx.argsort(similarities, axis=-1)[:, -num_to_merge:] # Create a list to track which indices to merge diff --git a/mlx_vlm/models/llava/vision.py b/mlx_vlm/models/llava/vision.py index 31c273400..c0efe667d 100644 --- a/mlx_vlm/models/llava/vision.py +++ b/mlx_vlm/models/llava/vision.py @@ -96,12 +96,12 @@ def __call__(self, queries, keys, values, mask=None): keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - output = mx.fast.scaled_dot_product_attention( + attn = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + output = attn.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output) + return self.out_proj(output), attn class MLP(nn.Module): @@ -130,11 +130,11 @@ def __init__(self, config: VisionConfig): def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: y = self.layer_norm1(x) - y = self.self_attn(y, y, y, mask) + y, attn = self.self_attn(y, y, y, mask) x = x + y y = self.layer_norm2(x) y = self.mlp(y) - return x + y + return x + y, attn class Encoder(nn.Module): @@ -207,17 +207,21 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, + output_attn: Optional[bool] = None, ) -> mx.array: x = self.embeddings(x) if self.config.model_type == "clip_vision_model": x = self.pre_layrnorm(x) encoder_states = (x,) if output_hidden_states else None + all_attns = () if output_attn else None for l in self.encoder.layers: - x = l(x, mask=None) + x, attn = l(x, mask=None) if output_hidden_states: encoder_states = encoder_states + (x,) + if output_attn: + all_attns = all_attns + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) return pooler_output, x, encoder_states diff --git a/mlx_vlm/models/llava_bunny/vision.py b/mlx_vlm/models/llava_bunny/vision.py index df3e3c579..294b1c88b 100644 --- a/mlx_vlm/models/llava_bunny/vision.py +++ b/mlx_vlm/models/llava_bunny/vision.py @@ -95,11 +95,11 @@ def __call__(self, queries, keys, values, mask=None): keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - output = mx.fast.scaled_dot_product_attention( + attn = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output) + output = attn.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.out_proj(output), attn class MHA(nn.Module): @@ -170,11 +170,11 @@ def __init__(self, config: VisionConfig): def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: y = self.layer_norm1(x) - y = self.self_attn(y, y, y, mask) + y, attn = self.self_attn(y, y, y, mask) x = x + y y = self.layer_norm2(x) y = self.mlp(y) - return x + y + return x + y, attn class Encoder(nn.Module): @@ -225,15 +225,19 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, + output_attn: Optional[bool] = None, ) -> mx.array: x = self.embeddings(x) encoder_states = (x,) if output_hidden_states else None + all_attns = () if output_attn else None for l in self.encoder.layers: - x = l(x, mask=None) + x, attn = l(x, mask=None) if output_hidden_states: encoder_states = encoder_states + (x,) + if output_attn: + all_attns = all_attns + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) pooler_output = self.head(pooler_output) diff --git a/mlx_vlm/models/llava_next/vision.py b/mlx_vlm/models/llava_next/vision.py index 16e74622a..5992c97c3 100644 --- a/mlx_vlm/models/llava_next/vision.py +++ b/mlx_vlm/models/llava_next/vision.py @@ -192,18 +192,20 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, + output_attn: bool = False, ) -> mx.array: x = self.embeddings(x) x = self.pre_layrnorm(x) encoder_states = (x,) if output_hidden_states else None - attns = tuple() + all_attns = () if output_attn else None for l in self.encoder.layers: x, attn = l(x, mask=None) if output_hidden_states: encoder_states = encoder_states + (x,) - attns = attns + (attn,) + if output_attn: + all_attns = all_attns + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) return pooler_output, x, encoder_states, attns diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index bc907628e..58de6bf09 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -10,6 +10,7 @@ import numpy as np from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -42,7 +43,7 @@ def from_dict(cls, params): ) -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config @@ -54,6 +55,7 @@ def get_input_embeddings( input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, image_grid_thw: Optional[mx.array] = None, + **kwargs, ): if pixel_values is None: @@ -66,13 +68,24 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - hidden_states = self.vision_tower( - pixel_values, image_grid_thw, output_hidden_states=False + hidden_states, all_attns = self.vision_tower( + pixel_values, image_grid_thw, output_hidden_states=False, output_attn=True ) if hidden_states.ndim == 2: hidden_states = hidden_states[None, :, :] + if all_attns: + attn = all_attns[-1] + vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) + vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) + hidden_states = self.filter_topk_vision_tokens( + hidden_states, attn, vision_filter_ratio + ) + hidden_states = self.merge_similar_vision_tokens( + hidden_states, vision_merge_ratio + ) + # Insert special image tokens in the input_ids final_inputs_embeds = self._merge_input_ids_with_image_features( hidden_states, inputs_embeds, input_ids @@ -104,7 +117,7 @@ def __call__( image_grid_thw = mx.array(image_grid_thw) input_embddings = self.get_input_embeddings( - input_ids, pixel_values, image_grid_thw + input_ids, pixel_values, image_grid_thw, **kwargs ) logits = self.language_model(None, cache=cache, inputs_embeds=input_embddings) diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index bd32514db..b5a26a0b0 100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -179,12 +179,12 @@ def __call__( k = k.transpose(0, 2, 1, 3) v = v.transpose(0, 2, 1, 3) - output = mx.fast.scaled_dot_product_attention( + attn = mx.fast.scaled_dot_product_attention( q, k, v, scale=self.scale, mask=attention_mask ) - output = output.transpose(0, 2, 1, 3) + output = attn.transpose(0, 2, 1, 3) output = output.reshape(seq_length, -1) - return self.proj(output) + return self.proj(output), attn class MLP(nn.Module): @@ -211,13 +211,14 @@ def __init__(self, config: VisionConfig) -> None: self.mlp = MLP(dim=config.embed_dim, hidden_dim=mlp_hidden_dim) def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array: - hidden_states = hidden_states + self.attn( + x, attn = self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, ) + hidden_states += x hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) - return hidden_states + return hidden_states, attn class VisionModel(nn.Module): @@ -286,6 +287,7 @@ def __call__( hidden_states: mx.array, grid_thw: mx.array, output_hidden_states: Optional[bool] = None, + output_attn: Optional[bool] = None, ) -> mx.array: hidden_states = self.patch_embed(hidden_states) @@ -307,15 +309,18 @@ def __call__( cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0) encoder_states = (hidden_states,) if output_hidden_states else None + all_attns = () if output_attn else None for blk in self.blocks: - hidden_states = blk( + hidden_states, attn = blk( hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb ) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) + if output_attn: + all_attns = all_attns + (attn,) - return self.merger(hidden_states) + return self.merger(hidden_states), all_attns def sanitize(self, weights): sanitized_weights = {} From 404887179784b568f524ab22b571e5ace760f28c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 19:47:06 +0100 Subject: [PATCH 08/19] fix token filtering and merging --- mlx_vlm/models/base.py | 8 +++----- mlx_vlm/models/pixtral/pixtral.py | 20 +++++++++++++++++--- mlx_vlm/models/pixtral/vision.py | 18 +++++++++++------- mlx_vlm/models/qwen2_vl/qwen2_vl.py | 20 ++++++++++++++++---- mlx_vlm/utils.py | 1 - 5 files changed, 47 insertions(+), 20 deletions(-) diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index 6e25605e1..ddb7307e0 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -286,22 +286,20 @@ def __init__(self): def filter_topk_vision_tokens(self, image_feature, attn, vision_filter_ratio=None): batch_size, seq_len = image_feature.shape[:2] - k_tokens = ( int(image_feature.shape[1] * vision_filter_ratio) if vision_filter_ratio is not None else None ) # keep 25% of the visual tokens - if k_tokens is None: + + if k_tokens is None or k_tokens == seq_len: return image_feature + cls_idx = 0 # self.config.image_token_index attn_rec = mx.sum(attn[:, :, cls_idx + 1 :, cls_idx], axis=1) topk_idx = mx.argsort(attn_rec, axis=1)[:, -k_tokens:] - # use this to plot the dominant attention map - # https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/model/multimodal_encoder/clip_encoder.py#L62 - # https://github.com/dvlab-research/VisionZip/blob/demo-chat/llava/serve/gradio_web_server.py#L424 # Create CLS token indices array # Shape: (B, 1) diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index 37dbc3cbb..1f926e625 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -68,6 +68,7 @@ def get_input_embeddings( self, input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, + **kwargs, ): if pixel_values is None: return self.language_model.model.embed_tokens(input_ids) @@ -96,12 +97,25 @@ def get_input_embeddings( # Pass pixel_values as list of images, as each image is individually run through conv2d and position encoding # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21 # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85 - *_, hidden_states = self.vision_tower( - [pv.transpose(0, 2, 3, 1) for pv in pixel_values], output_hidden_states=True + *_, hidden_states, all_attns = self.vision_tower( + [pv.transpose(0, 2, 3, 1) for pv in pixel_values], + output_hidden_states=True, + output_attentions=True, ) # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] + if all_attns: + attn = all_attns[self.vision_feature_layer] + vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) + vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) + selected_image_feature = self.filter_topk_vision_tokens( + selected_image_feature, attn, vision_filter_ratio + ) + selected_image_feature = self.merge_similar_vision_tokens( + selected_image_feature, vision_merge_ratio + ) + # Pass image features through the multi-modal projector image_features = self.multi_modal_projector(selected_image_feature) @@ -144,7 +158,7 @@ def __call__( cache=None, **kwargs, ): - input_embddings = self.get_input_embeddings(input_ids, pixel_values) + input_embddings = self.get_input_embeddings(input_ids, pixel_values, **kwargs) logits = self.language_model( input_ids, cache=cache, inputs_embeds=input_embddings ) diff --git a/mlx_vlm/models/pixtral/vision.py b/mlx_vlm/models/pixtral/vision.py index 1f015ba0f..20387810e 100644 --- a/mlx_vlm/models/pixtral/vision.py +++ b/mlx_vlm/models/pixtral/vision.py @@ -153,11 +153,11 @@ def __call__(self, queries, keys, values, position_embeddings, mask=None): attn_weights = attn_weights + mask attn_weights = mx.softmax(attn_weights, axis=-1) - output = mx.matmul(attn_weights, values) + attn = mx.matmul(attn_weights, values) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + output = attn.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) + return self.o_proj(output), attn class MLP(nn.Module): @@ -191,11 +191,11 @@ def __call__( mask: Optional[mx.array] = None, ) -> mx.array: y = self.attention_norm(x) - y = self.attention(y, y, y, position_embeddings, mask) + y, attn = self.attention(y, y, y, position_embeddings, mask) x = x + y y = self.ffn_norm(x) y = self.feed_forward(y) - return x + y + return x + y, attn class Encoder(nn.Module): @@ -255,6 +255,7 @@ def __call__( self, x: List[mx.array], output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, ) -> mx.array: B, H, W, C = x[0].shape patch_embeds_list = [self.patch_conv(img) for img in x] @@ -277,15 +278,18 @@ def __call__( ) encoder_states = (patch_embeds,) if output_hidden_states else None + all_attns = () if output_attentions else None for l in self.transformer.layers: - patch_embeds = l( + patch_embeds, attn = l( patch_embeds, mask=mask, position_embeddings=position_embedding ) if output_hidden_states: encoder_states = encoder_states + (patch_embeds,) + if output_attentions: + all_attns = all_attns + (attn,) - return patch_embeds, encoder_states + return patch_embeds, encoder_states, all_attns class VisionModel(nn.Module): diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index 58de6bf09..6b79eb1ef 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -96,13 +96,25 @@ def _merge_input_ids_with_image_features( self, image_features, inputs_embeds, input_ids ): image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape # Positions of tokens in input_ids, assuming batch size is 1 - image_positions = input_ids == image_token_index - image_indices = np.where(image_positions)[1].tolist() - inputs_embeds[:, image_indices, :] = image_features + image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() - return inputs_embeds + text_segments = [] + start_idx = 0 + + for position in image_positions: + text_segments.append(inputs_embeds[:, start_idx:position]) + start_idx = position + 1 + + image_embeddings = mx.split(image_features, image_features.shape[0]) + final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] + final_embeddings += [inputs_embeds[:, start_idx:]] + + # Create a final embedding of shape + # (1, num_image_patches*num_images + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=1) def __call__( self, diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index d9dbf8a8b..3f0ab0559 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -1112,7 +1112,6 @@ def generate( if len(text) == 0: print("No text generated for this prompt") return - total_tokens = ( last_response.prompt_tokens * vision_merge_ratio ) * vision_filter_ratio From 42c5f572f4ecc464ffb7a4308ca3a93a6f49a86f Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 19:50:25 +0100 Subject: [PATCH 09/19] add baseModel --- mlx_vlm/models/pixtral/pixtral.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index 1f926e625..37e5301df 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -10,6 +10,7 @@ import numpy as np from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -54,7 +55,7 @@ def __call__(self, x: mx.array) -> mx.array: return x -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config From b8491ad08e6a6ac93a12a96da8d445242b2e83a8 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 20:05:12 +0100 Subject: [PATCH 10/19] fix attentions and NoneType --- mlx_vlm/generate.py | 1 + mlx_vlm/models/pixtral/vision.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 33fa7f44b..3a17f8544 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -86,6 +86,7 @@ def parse_arguments(): parser.add_argument( "--vision-filter-ratio", type=float, + default=1.0, help="Ratio of vision tokens to keep during filtering topk tokens (between 0.1 and 1.0).", choices=[x / 10 for x in range(1, 11)], ) diff --git a/mlx_vlm/models/pixtral/vision.py b/mlx_vlm/models/pixtral/vision.py index 20387810e..11b6c8d31 100644 --- a/mlx_vlm/models/pixtral/vision.py +++ b/mlx_vlm/models/pixtral/vision.py @@ -303,9 +303,12 @@ def __init__(self, config: VisionConfig): self.vision_model = PixtralVisionModel(config) def __call__( - self, x: List[mx.array], output_hidden_states: Optional[bool] = None + self, + x: List[mx.array], + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, ) -> mx.array: - return self.vision_model(x, output_hidden_states) + return self.vision_model(x, output_hidden_states, output_attentions) def sanitize(self, weights): sanitized_weights = {} From 4dad82839f563ac4dfa4ccf95c749908a4afe49b Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 19 Jan 2025 20:29:43 +0100 Subject: [PATCH 11/19] add filtering and merging to paligemma --- mlx_vlm/models/paligemma/paligemma.py | 21 ++++++++++++++--- mlx_vlm/models/paligemma/vision.py | 34 +++++++++++++++++---------- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/mlx_vlm/models/paligemma/paligemma.py b/mlx_vlm/models/paligemma/paligemma.py index b8179b935..8223aa286 100644 --- a/mlx_vlm/models/paligemma/paligemma.py +++ b/mlx_vlm/models/paligemma/paligemma.py @@ -9,6 +9,7 @@ import mlx.nn as nn from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -49,7 +50,7 @@ def __call__(self, x: mx.array) -> mx.array: return output -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.model_type = config.model_type @@ -64,18 +65,32 @@ def get_input_embeddings( input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, mask: Optional[mx.array] = None, + **kwargs, ): if pixel_values is None: return self.language_model.model.embed_tokens(input_ids), None inputs_embeds = self.language_model.model.embed_tokens(input_ids) - hidden_state, _, _ = self.vision_tower( + hidden_state, _, _, all_attns = self.vision_tower( pixel_values.transpose(0, 2, 3, 1).astype(inputs_embeds.dtype), output_hidden_states=True, + output_attentions=True, ) image_features = hidden_state[None, :].astype(pixel_values.dtype) + + if all_attns: + attn = all_attns[-1] + vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) + vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) + image_features = self.filter_topk_vision_tokens( + image_features, attn, vision_filter_ratio + ) + image_features = self.merge_similar_vision_tokens( + image_features, vision_merge_ratio + ) + image_features = self.multi_modal_projector(image_features) final_inputs_embeds, final_attention_mask_4d = ( @@ -143,7 +158,7 @@ def __call__( **kwargs, ): input_embeddings, final_attention_mask_4d = self.get_input_embeddings( - input_ids, pixel_values, mask + input_ids, pixel_values, mask, **kwargs ) logits = self.language_model( diff --git a/mlx_vlm/models/paligemma/vision.py b/mlx_vlm/models/paligemma/vision.py index d97ed080f..55d9af8bc 100644 --- a/mlx_vlm/models/paligemma/vision.py +++ b/mlx_vlm/models/paligemma/vision.py @@ -94,11 +94,11 @@ def __call__(self, x, mask=None): keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - output = mx.fast.scaled_dot_product_attention( + attn = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output) + output = attn.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.out_proj(output), attn class MLP(nn.Module): @@ -127,10 +127,10 @@ def __init__(self, config: VisionConfig): self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: - r = self.self_attn(self.layer_norm1(x), mask) + r, attn = self.self_attn(self.layer_norm1(x), mask) h = x + r r = self.mlp(self.layer_norm2(h)) - return h + r + return h + r, attn class Encoder(nn.Module): @@ -141,19 +141,23 @@ def __init__(self, config: VisionConfig): def __call__( self, x: mx.array, - output_hidden_states: Optional[bool] = None, mask: Optional[mx.array] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, ) -> mx.array: encoder_states = (x,) if output_hidden_states else None + all_attns = () if output_attentions else None h = x for l in self.layers: - x = l(x, mask=mask) + x, attn = l(x, mask=mask) if output_hidden_states: encoder_states = encoder_states + (x,) + if output_attentions: + all_attns = all_attns + (attn,) h = x[0] - return (h, encoder_states) + return (h, encoder_states, all_attns) class VisionEmbeddings(nn.Module): @@ -195,16 +199,17 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, ) -> mx.array: x = self.embeddings(x) - encoder_outputs = self.encoder( + h, encoder_states, all_attns = self.encoder( x=x, output_hidden_states=output_hidden_states, mask=None ) - pooler_output = self.post_layernorm(encoder_outputs[0]) + pooler_output = self.post_layernorm(h) - return pooler_output, x, encoder_outputs[-1] + return pooler_output, x, encoder_states, all_attns class VisionModel(nn.Module): @@ -217,9 +222,12 @@ def __init__(self, config: VisionConfig): self.vision_model = SigLipVisionModel(config) def __call__( - self, x: mx.array, output_hidden_states: Optional[bool] = None + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, ) -> mx.array: - return self.vision_model(x, output_hidden_states) + return self.vision_model(x, output_hidden_states, output_attentions) def sanitize(self, weights): sanitized_weights = {} From 19d6e7e5d1cf565e709b25580f2e15908a8d1cc0 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 27 Jan 2025 20:28:08 +0100 Subject: [PATCH 12/19] stash progress --- mlx_vlm/models/base.py | 21 ++++++++++--- mlx_vlm/models/molmo/molmo.py | 27 ++++++++++++++-- mlx_vlm/models/molmo/vision.py | 56 ++++++++++++++++++++++++---------- 3 files changed, 82 insertions(+), 22 deletions(-) diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index ddb7307e0..d14cf5d6a 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -310,7 +310,7 @@ def filter_topk_vision_tokens(self, image_feature, attn, vision_filter_ratio=Non dominant_idx = mx.concatenate([cls_indices, topk_idx + cls_idx + 1], axis=1) image_feature = mx.take(image_feature, dominant_idx, axis=1)[0] - return image_feature + return image_feature, dominant_idx def merge_similar_vision_tokens( self, image_feature, vision_merge_ratio, merge_rate=0.4 @@ -322,6 +322,9 @@ def merge_similar_vision_tokens( # Calculate target number of tokens target_tokens = max(1, int(num_tokens * vision_merge_ratio)) + # Create a mask of the same shape as tokens, initialized to True + mask = mx.ones((batch_size, num_tokens)) + while num_tokens > target_tokens: # Calculate similarities between adjacent tokens tokens_a = tokens[:, :-1] # all except last @@ -343,12 +346,14 @@ def merge_similar_vision_tokens( # Merge selected pairs new_tokens = [] + new_mask = [] i = 0 while i < num_tokens: if i < num_tokens - 1 and i in to_merge: # Merge this token with the next one merged = (tokens[:, i : i + 1] + tokens[:, i + 1 : i + 2]) / 2 new_tokens.append(merged) + new_mask.append(mask[:, i : i + 1]) # Keep mask from first token i += 2 elif i > 0 and (i - 1) in to_merge: # Skip this token as it was merged in the previous step @@ -356,11 +361,19 @@ def merge_similar_vision_tokens( else: # Keep this token as is new_tokens.append(tokens[:, i : i + 1]) + new_mask.append(mask[:, i : i + 1]) i += 1 - # Update tokens + # Update tokens and mask tokens = mx.concatenate(new_tokens, axis=1) + mask = mx.concatenate(new_mask, axis=1) num_tokens = tokens.shape[1] - # Reattach CLS token - return mx.concatenate([image_feature[:, :1], tokens], axis=1) + # Add back CLS token + cls_mask = mx.ones((batch_size, 1), dtype=mx.bool_) + return mx.concatenate([image_feature[:, :1], tokens], axis=1), mx.concatenate( + [cls_mask, mask], axis=1 + ) + + def merge_vision_patches(self, image_feature, vision_merge_ratio, merge_rate=0.4): + pass diff --git a/mlx_vlm/models/molmo/molmo.py b/mlx_vlm/models/molmo/molmo.py index e3e345d8f..37fea1435 100644 --- a/mlx_vlm/models/molmo/molmo.py +++ b/mlx_vlm/models/molmo/molmo.py @@ -10,6 +10,7 @@ import numpy as np from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -36,7 +37,7 @@ def from_dict(cls, params): ) -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config @@ -79,7 +80,9 @@ def __call__( else None ) - image_features, cls_embed = self.vision_tower(pixel_values, image_masks) + image_features, cls_embed, all_attns = self.vision_tower( + pixel_values, image_masks, output_attentions=True + ) # Insert image features into the input embeddings num_image, num_patch = image_features.shape[1:3] @@ -94,8 +97,28 @@ def __call__( image_features = image_features.reshape( batch_size, num_image * num_patch, -1 ) + image_input_idx = image_input_idx.reshape(batch_size, num_image * num_patch) + all_attns = all_attns.reshape(batch_size, num_image * num_patch, -1) + print("all_attns", all_attns.shape) + print("image_features", image_features.shape) + print("image_input_idx", image_input_idx.shape) + if all_attns is not None: + attn = all_attns[None, ...] + vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) + vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) + image_features, filter_mask = self.filter_topk_vision_tokens( + image_features, attn, vision_filter_ratio + ) + image_input_idx = mx.take(image_input_idx, filter_mask[0], axis=1) + + image_features, merge_mask = self.merge_similar_vision_tokens( + image_features, vision_merge_ratio + ) + merge_mask = mx.array(np.where(merge_mask > 0)[1]) + image_input_idx = mx.take(image_input_idx, merge_mask, axis=1) + valid = np.where(image_input_idx >= 0)[0].tolist() batch_idx = mx.arange(batch_size) batch_idx = mx.tile(batch_idx[:, None], [1, image_features.shape[1]]) diff --git a/mlx_vlm/models/molmo/vision.py b/mlx_vlm/models/molmo/vision.py index d42ae47f8..d84628dca 100644 --- a/mlx_vlm/models/molmo/vision.py +++ b/mlx_vlm/models/molmo/vision.py @@ -165,7 +165,7 @@ def __call__(self, x: mx.array, kv: mx.array = None) -> mx.array: out = attn.transpose(0, 2, 1, 3) out = self._merge_heads(out) out = self.wo(out) - return out + return out, attn class ResidualAttentionBlock(nn.Module): @@ -180,9 +180,10 @@ def __init__(self, config: VisionConfig): self.ffn_norm = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps) def __call__(self, x: mx.array) -> mx.array: - x = x + self.attention(self.attention_norm(x)) + h, attn = self.attention(self.attention_norm(x)) + x = x + h x = x + self.feed_forward(self.ffn_norm(x)) - return x + return x, attn class ResidualAttentionBlocks(nn.Module): @@ -194,10 +195,12 @@ def __init__(self, config: VisionConfig): def __call__(self, x: mx.array) -> mx.array: h = [] + attns = [] for block in self.resblocks: - x = block(x) + x, attn = block(x) h.append(x) - return h + attns.append(attn) + return h, attns def _expand_token(token, batch_size: int): @@ -320,8 +323,8 @@ def __call__(self, x: mx.array, patch_num: int = None) -> List[mx.array]: x = self.pre_ln(x) - hidden_states = self.transformer(x) - return hidden_states + hidden_states, attn = self.transformer(x) + return hidden_states, attn class VisionModel(nn.Module): @@ -340,7 +343,9 @@ def __init__(self, config): self.image_projector = MLP(config, config.image_emb_dim) self.pad_embed = mx.zeros((2, config.image_emb_dim * 2)) - def encode_image(self, images: mx.array) -> mx.array: + def encode_image( + self, images: mx.array, output_attentions: Optional[bool] = False + ) -> mx.array: """ : param images: (batch_size, num_crops, num_patch, n_pixels) """ @@ -353,7 +358,7 @@ def encode_image(self, images: mx.array) -> mx.array: # Output all hidden states images = reshaped_images - image_features = self.image_vit(images) + image_features, all_attns = self.image_vit(images) if cfg.vit_layers is not None: features = [] @@ -373,15 +378,21 @@ def encode_image(self, images: mx.array) -> mx.array: cls_embed = mx.reshape(cls_embed, (B, T, -1)) if cls_embed is not None else None - return image_features, cls_embed + if output_attentions: + return image_features, cls_embed, all_attns + else: + return image_features, cls_embed def __call__( - self, images: mx.array, image_masks: mx.array + self, + images: mx.array, + image_masks: mx.array, + output_attentions: Optional[bool] = None, ) -> Tuple[mx.array, Optional[mx.array]]: cfg = self.config batch_size, num_image = images.shape[:2] - image_features, cls_embed = self.encode_image(images) + image_features, cls_embed, _ = self.encode_image(images, output_attentions=True) if cfg.image_padding_embed: assert image_masks is not None @@ -448,18 +459,31 @@ def __call__( ), ) + all_attns = None if cfg.image_pooling_2d == "attention-meanq": query = mx.mean(image_features, axis=-2, keepdims=True) - image_features = self.image_pooling_2d(query, image_features) + image_features, all_attns = self.image_pooling_2d(query, image_features) elif cfg.image_pooling_2d not in {"none", "stack"}: - image_features = self.image_pooling_2d( + image_features, all_attns = self.image_pooling_2d( image_features[:, :1, :], image_features ) + if all_attns is None: + raise ValueError("Attention is None") + h, w = cfg.llm_patches_per_crop + image_features = mx.reshape(image_features, (batch_size, num_image, h * w, -1)) - # # MLP layer to map the feature + all_attns = all_attns.reshape( + all_attns.shape[0], -1, all_attns.shape[1] * all_attns.shape[-1] + ) + all_attns = mx.reshape(all_attns, (batch_size, num_image, h * w, -1)) + + # MLP layer to map the feature image_features = self.image_projector(image_features) - return image_features, cls_embed + if output_attentions: + return image_features, cls_embed, all_attns + else: + return image_features, cls_embed From f65b985f4f715ecdf5ad4642e297e661024ecd22 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 2 Mar 2025 18:46:26 +0100 Subject: [PATCH 13/19] fix video generate --- mlx_vlm/video_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_vlm/video_generate.py b/mlx_vlm/video_generate.py index 81716c58e..98dc200bc 100644 --- a/mlx_vlm/video_generate.py +++ b/mlx_vlm/video_generate.py @@ -466,7 +466,7 @@ def main(): "Warning: The model selected doesn't natively support video inputs. Performance may be degraded." ) - if isinstance(args.max_pixels, tuple): + if isinstance(args.max_pixels, tuple) or isinstance(args.max_pixels, list): max_pixels = args.max_pixels[0] * args.max_pixels[1] else: max_pixels = args.max_pixels From 5b4127b0d55f20c24be2eb8902a9cd53d2ae9328 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 2 Mar 2025 18:46:45 +0100 Subject: [PATCH 14/19] update temp --- README.md | 2 +- mlx_vlm/generate.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f71a942cf..4a3041774 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ pip install mlx-vlm Generate output from a model using the CLI: ```sh -python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temp 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg +python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temperature 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg ``` ### Chat UI with Gradio diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 748f1c264..470679908 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -15,7 +15,7 @@ DEFAULT_IMAGE = [] DEFAULT_PROMPT = "What are these?" DEFAULT_MAX_TOKENS = 256 -DEFAULT_TEMP = 0.5 +DEFAULT_TEMPERATURE = 0.5 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 @@ -71,7 +71,7 @@ def parse_arguments(): parser.add_argument( "--temperature", type=float, - default=DEFAULT_TEMP, + default=DEFAULT_TEMPERATURE, help="Temperature for sampling.", ) parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.") From 8e730daf86480cdd9e100e147ea3286d72eccb28 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 2 Mar 2025 18:58:24 +0100 Subject: [PATCH 15/19] Add token merging and filtering --- mlx_vlm/models/base.py | 16 +++++++++--- .../models/deepseek_vl_v2/deepseek_vl_v2.py | 4 +-- mlx_vlm/models/florence2/florence2.py | 2 +- mlx_vlm/models/idefics2/idefics2.py | 2 +- mlx_vlm/models/idefics3/idefics3.py | 3 ++- mlx_vlm/models/llava/llava.py | 3 ++- mlx_vlm/models/llava_bunny/llava_bunny.py | 4 +-- mlx_vlm/models/llava_next/llava_next.py | 23 ++++++++++++++--- mlx_vlm/models/llava_next/vision.py | 9 ++++--- mlx_vlm/models/mllama/mllama.py | 4 +-- mlx_vlm/models/molmo/molmo.py | 8 +++--- .../models/multi_modality/multi_modality.py | 4 +-- mlx_vlm/models/pixtral/pixtral.py | 2 +- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py | 25 ++++++++++++++++--- mlx_vlm/models/qwen2_5_vl/vision.py | 20 +++++++++------ 15 files changed, 91 insertions(+), 38 deletions(-) diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index d14cf5d6a..9604fc4bb 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -284,7 +284,9 @@ def __init__(self): self.vision_tower = None self.language_model = None - def filter_topk_vision_tokens(self, image_feature, attn, vision_filter_ratio=None): + def filter_topk_vision_tokens( + self, image_feature, attn, vision_filter_ratio=None + ) -> Tuple[mx.array, mx.array]: batch_size, seq_len = image_feature.shape[:2] k_tokens = ( int(image_feature.shape[1] * vision_filter_ratio) @@ -293,7 +295,7 @@ def filter_topk_vision_tokens(self, image_feature, attn, vision_filter_ratio=Non ) # keep 25% of the visual tokens if k_tokens is None or k_tokens == seq_len: - return image_feature + return image_feature, None cls_idx = 0 # self.config.image_token_index @@ -314,7 +316,7 @@ def filter_topk_vision_tokens(self, image_feature, attn, vision_filter_ratio=Non def merge_similar_vision_tokens( self, image_feature, vision_merge_ratio, merge_rate=0.4 - ): + ) -> Tuple[mx.array, mx.array]: # Skip CLS token (first token) tokens = image_feature[:, 1:] batch_size, num_tokens, hidden_dim = tokens.shape @@ -322,6 +324,9 @@ def merge_similar_vision_tokens( # Calculate target number of tokens target_tokens = max(1, int(num_tokens * vision_merge_ratio)) + if num_tokens == target_tokens: + return image_feature, None + # Create a mask of the same shape as tokens, initialized to True mask = mx.ones((batch_size, num_tokens)) @@ -376,4 +381,7 @@ def merge_similar_vision_tokens( ) def merge_vision_patches(self, image_feature, vision_merge_ratio, merge_rate=0.4): + """ + Merge vision patches based on the vision_merge_ratio and merge_rate. + """ pass diff --git a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py index 3629bcc44..2849168cf 100644 --- a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +++ b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py @@ -15,7 +15,7 @@ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature from transformers.image_utils import to_numpy_array -from ..base import expand2square +from ..base import BaseModel, expand2square from .language import LanguageModel, TextConfig from .processing_deepsek_vl_v2 import DeepseekVLV2Processor from .vision import VisionConfig, VisionModel @@ -194,7 +194,7 @@ def __call__(self, x): return x -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config diff --git a/mlx_vlm/models/florence2/florence2.py b/mlx_vlm/models/florence2/florence2.py index 7895c4e34..60ebe8ce0 100644 --- a/mlx_vlm/models/florence2/florence2.py +++ b/mlx_vlm/models/florence2/florence2.py @@ -156,7 +156,7 @@ def __call__(self, seq_embeds: mx.array) -> mx.array: return pos_embeds -class Model(nn.Module): +class Model(BaseModel): """Florence-2 model for conditional generation.""" def __init__(self, config: ModelConfig): diff --git a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py index 294dd9a13..1f7426891 100644 --- a/mlx_vlm/models/idefics2/idefics2.py +++ b/mlx_vlm/models/idefics2/idefics2.py @@ -200,7 +200,7 @@ def __call__(self, x: mx.array, mask=None) -> mx.array: return x -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.model_type = config.model_type diff --git a/mlx_vlm/models/idefics3/idefics3.py b/mlx_vlm/models/idefics3/idefics3.py index 6b24722e5..b83c48cb2 100644 --- a/mlx_vlm/models/idefics3/idefics3.py +++ b/mlx_vlm/models/idefics3/idefics3.py @@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -81,7 +82,7 @@ def __call__(self, image_hidden_states): return image_hidden_states -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.model_type = config.model_type diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py index 50041510f..677a44c4b 100644 --- a/mlx_vlm/models/llava/llava.py +++ b/mlx_vlm/models/llava/llava.py @@ -10,6 +10,7 @@ import numpy as np from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -54,7 +55,7 @@ def __call__(self, x: mx.array) -> mx.array: return x -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index 5363752f2..b90db8c25 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -22,7 +22,7 @@ ) from transformers.image_utils import to_numpy_array -from ..base import BaseImageProcessor +from ..base import BaseImageProcessor, BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -125,7 +125,7 @@ def __call__( return self.vision_tower(x, output_hidden_states) -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.model_type = config.model_type diff --git a/mlx_vlm/models/llava_next/llava_next.py b/mlx_vlm/models/llava_next/llava_next.py index 20a43137b..1505e7473 100644 --- a/mlx_vlm/models/llava_next/llava_next.py +++ b/mlx_vlm/models/llava_next/llava_next.py @@ -10,6 +10,7 @@ import numpy as np from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -54,7 +55,7 @@ def __call__(self, x: mx.array) -> mx.array: return x -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config @@ -73,6 +74,7 @@ def get_input_embeddings( self, input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, + **kwargs, ): if pixel_values is None: return self.language_model.model.embed_tokens(input_ids) @@ -81,8 +83,10 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - *_, hidden_states = self.vision_tower( - pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True + *_, hidden_states, all_attns = self.vision_tower( + pixel_values[0].transpose(0, 2, 3, 1), + output_hidden_states=True, + output_attn=True, ) # Select the hidden states from the desired layer @@ -98,6 +102,17 @@ def get_input_embeddings( f"{self.vision_feature_select_strategy}" ) + if all_attns: + attn = all_attns[-1] + vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) + vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) + selected_image_feature, _ = self.filter_topk_vision_tokens( + selected_image_feature, attn, vision_filter_ratio + ) + selected_image_feature, _ = self.merge_similar_vision_tokens( + selected_image_feature, vision_merge_ratio + ) + # Pass image features through the multi-modal projector image_features = self.multi_modal_projector(selected_image_feature) @@ -148,7 +163,7 @@ def __call__( **kwargs, ): - input_embddings = self.get_input_embeddings(input_ids, pixel_values) + input_embddings = self.get_input_embeddings(input_ids, pixel_values, **kwargs) logits = self.language_model( input_ids, cache=cache, inputs_embeds=input_embddings ) diff --git a/mlx_vlm/models/llava_next/vision.py b/mlx_vlm/models/llava_next/vision.py index 5992c97c3..3443c146f 100644 --- a/mlx_vlm/models/llava_next/vision.py +++ b/mlx_vlm/models/llava_next/vision.py @@ -208,7 +208,7 @@ def __call__( all_attns = all_attns + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) - return pooler_output, x, encoder_states, attns + return pooler_output, x, encoder_states, all_attns class VisionModel(nn.Module): @@ -222,9 +222,12 @@ def __init__(self, config: VisionConfig): self.vision_model = ClipVisionModel(config) def __call__( - self, x: mx.array, output_hidden_states: Optional[bool] = None + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + output_attn: bool = False, ) -> mx.array: - return self.vision_model(x, output_hidden_states) + return self.vision_model(x, output_hidden_states, output_attn) def sanitize(self, weights): sanitized_weights = {} diff --git a/mlx_vlm/models/mllama/mllama.py b/mlx_vlm/models/mllama/mllama.py index 4bb8bc29c..b51e18563 100644 --- a/mlx_vlm/models/mllama/mllama.py +++ b/mlx_vlm/models/mllama/mllama.py @@ -9,7 +9,7 @@ import mlx.nn as nn from huggingface_hub import snapshot_download -from ..base import KVCache +from ..base import BaseModel, KVCache from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -36,7 +36,7 @@ def from_dict(cls, params): ) -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config diff --git a/mlx_vlm/models/molmo/molmo.py b/mlx_vlm/models/molmo/molmo.py index 37fea1435..fd2ec1fa0 100644 --- a/mlx_vlm/models/molmo/molmo.py +++ b/mlx_vlm/models/molmo/molmo.py @@ -113,9 +113,11 @@ def __call__( ) image_input_idx = mx.take(image_input_idx, filter_mask[0], axis=1) - image_features, merge_mask = self.merge_similar_vision_tokens( - image_features, vision_merge_ratio - ) + if vision_merge_ratio < 1: + print( + "Operation skipped: Molmo architecture does not support vision token merging" + ) + merge_mask = mx.array(np.where(merge_mask > 0)[1]) image_input_idx = mx.take(image_input_idx, merge_mask, axis=1) diff --git a/mlx_vlm/models/multi_modality/multi_modality.py b/mlx_vlm/models/multi_modality/multi_modality.py index d512abcd8..ed4099e4b 100644 --- a/mlx_vlm/models/multi_modality/multi_modality.py +++ b/mlx_vlm/models/multi_modality/multi_modality.py @@ -13,7 +13,7 @@ from transformers.image_processing_utils import BatchFeature from transformers.image_utils import to_numpy_array -from ..base import BaseImageProcessor, expand2square +from ..base import BaseImageProcessor, BaseModel, expand2square from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -240,7 +240,7 @@ def __call__(self, x: Union[mx.array, Tuple]) -> mx.array: return x -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index c64bb8bfd..9e1ff7cd9 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -100,7 +100,7 @@ def get_input_embeddings( attn = all_attns[self.vision_feature_layer] vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) - selected_image_feature = self.filter_topk_vision_tokens( + selected_image_feature, _ = self.filter_topk_vision_tokens( selected_image_feature, attn, vision_filter_ratio ) selected_image_feature = self.merge_similar_vision_tokens( diff --git a/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py b/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py index 31f7e2999..15b61f72a 100644 --- a/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +++ b/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py @@ -10,6 +10,7 @@ import numpy as np from huggingface_hub import snapshot_download +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -46,7 +47,7 @@ def from_dict(cls, params): ) -class Model(nn.Module): +class Model(BaseModel): def __init__(self, config: ModelConfig): super().__init__() self.config = config @@ -58,6 +59,7 @@ def get_input_embeddings( input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, image_grid_thw: Optional[mx.array] = None, + **kwargs, ): if pixel_values is None: @@ -70,13 +72,25 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - hidden_states = self.vision_tower( - pixel_values, image_grid_thw, output_hidden_states=False + hidden_states, all_attns = self.vision_tower( + pixel_values, image_grid_thw, output_hidden_states=False, output_attn=True ) if hidden_states.ndim == 2: hidden_states = hidden_states[None, :, :] + if all_attns: + attn = all_attns[-1] + + vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) + vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) + hidden_states, _ = self.filter_topk_vision_tokens( + hidden_states, attn, vision_filter_ratio + ) + hidden_states, _ = self.merge_similar_vision_tokens( + hidden_states, vision_merge_ratio + ) + # Insert special image tokens in the input_ids final_inputs_embeds = self._merge_input_ids_with_image_features( hidden_states, inputs_embeds, input_ids @@ -88,6 +102,7 @@ def _merge_input_ids_with_image_features( ): image_token_id = self.config.image_token_id video_token_id = self.config.video_token_id + # Positions of tokens in input_ids, assuming batch size is 1 image_positions = input_ids == image_token_id if mx.sum(image_positions) == 0: @@ -114,7 +129,9 @@ def __call__( if image_grid_thw is not None: image_grid_thw = mx.array(image_grid_thw) - inputs_embeds = self.get_input_embeddings(input_ids, pixel_values, grid_thw) + inputs_embeds = self.get_input_embeddings( + input_ids, pixel_values, grid_thw, **kwargs + ) logits = self.language_model(None, cache=cache, inputs_embeds=inputs_embeds) return logits diff --git a/mlx_vlm/models/qwen2_5_vl/vision.py b/mlx_vlm/models/qwen2_5_vl/vision.py index a40da4613..15d08c203 100644 --- a/mlx_vlm/models/qwen2_5_vl/vision.py +++ b/mlx_vlm/models/qwen2_5_vl/vision.py @@ -187,12 +187,12 @@ def __call__( k = k.transpose(0, 2, 1, 3) v = v.transpose(0, 2, 1, 3) - output = mx.fast.scaled_dot_product_attention( + attn = mx.fast.scaled_dot_product_attention( q, k, v, scale=self.scale, mask=attention_mask ) - output = output.transpose(0, 2, 1, 3) + output = attn.transpose(0, 2, 1, 3) output = output.reshape(seq_length, -1) - return self.proj(output) + return self.proj(output), attn class MLP(nn.Module): @@ -216,13 +216,14 @@ def __init__(self, config: VisionConfig) -> None: self.mlp = MLP(dim=config.hidden_size, hidden_dim=config.intermediate_size) def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array: - hidden_states = hidden_states + self.attn( + x, attn = self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, ) + hidden_states += x hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) - return hidden_states + return hidden_states, attn class VisionModel(nn.Module): @@ -362,6 +363,7 @@ def __call__( hidden_states: mx.array, grid_thw: mx.array, output_hidden_states: Optional[bool] = None, + output_attn: Optional[bool] = None, ) -> mx.array: hidden_states = self.patch_embed(hidden_states) @@ -407,6 +409,7 @@ def __call__( cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0) encoder_states = (hidden_states,) if output_hidden_states else None + all_attns = () if output_attn else None for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: @@ -414,17 +417,20 @@ def __call__( else: cu_seqlens_now = cu_window_seqlens - hidden_states = blk( + hidden_states, attn = blk( hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb ) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) + if output_attn: + all_attns = all_attns + (attn,) + hidden_states = self.merger(hidden_states) reverse_indices = mx.argsort(window_index, axis=0) hidden_states = hidden_states[reverse_indices, :] - return hidden_states + return hidden_states, all_attns def sanitize(self, weights): sanitized_weights = {} From 1f47d4fcec12199a7bfac628c616f2b0fbc30a9a Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 2 Mar 2025 21:18:25 +0100 Subject: [PATCH 16/19] fix vision tests --- mlx_vlm/models/base.py | 8 +++++ .../models/deepseek_vl_v2/deepseek_vl_v2.py | 3 +- mlx_vlm/models/deepseek_vl_v2/vision.py | 6 +++- mlx_vlm/models/florence2/florence2.py | 4 ++- mlx_vlm/models/florence2/vision.py | 4 ++- mlx_vlm/models/idefics2/idefics2.py | 5 +-- mlx_vlm/models/idefics2/vision.py | 6 +++- mlx_vlm/models/idefics3/idefics3.py | 3 +- mlx_vlm/models/idefics3/vision.py | 6 +++- mlx_vlm/models/llava/llava.py | 3 +- mlx_vlm/models/llava/vision.py | 16 +++++++--- mlx_vlm/models/llava_bunny/llava_bunny.py | 4 +-- mlx_vlm/models/llava_bunny/vision.py | 16 +++++++--- mlx_vlm/models/llava_next/llava_next.py | 6 ++-- mlx_vlm/models/llava_next/vision.py | 20 ++++++++---- mlx_vlm/models/mllama/mllama.py | 4 +-- mlx_vlm/models/mllama/vision.py | 4 ++- mlx_vlm/models/molmo/molmo.py | 8 ++--- mlx_vlm/models/molmo/vision.py | 7 ++-- mlx_vlm/models/multi_modality/vision.py | 7 ++-- mlx_vlm/models/paligemma/paligemma.py | 5 ++- mlx_vlm/models/paligemma/vision.py | 11 +++++-- mlx_vlm/models/phi3_v/vision.py | 9 ++++-- mlx_vlm/models/pixtral/pixtral.py | 6 +++- mlx_vlm/models/pixtral/vision.py | 8 ++++- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py | 10 ++++-- mlx_vlm/models/qwen2_5_vl/vision.py | 12 ++++--- mlx_vlm/models/qwen2_vl/qwen2_vl.py | 14 +++++--- mlx_vlm/models/qwen2_vl/vision.py | 14 +++++--- mlx_vlm/tests/test_models.py | 32 +++++++++++++++++-- mlx_vlm/tests/test_smoke.py | 6 ++++ 31 files changed, 197 insertions(+), 70 deletions(-) diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index 9604fc4bb..f581e1ba2 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -278,6 +278,14 @@ class LanguageModelOutput: encoder_outputs: Optional[List[mx.array]] = None +@dataclass +class VisionModelOutput: + hidden_states: Optional[mx.array] = None + encoder_states: Optional[List[mx.array]] = None + attentions: Optional[List[mx.array]] = None + pooler_output: Optional[mx.array] = None + + class BaseModel(nn.Module): def __init__(self): super().__init__() diff --git a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py index 2849168cf..1283e232f 100644 --- a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +++ b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py @@ -409,9 +409,10 @@ def get_input_embeddings( input_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - hidden_states, *_ = self.vision( + vision_output = self.vision( total_tiles.transpose(0, 2, 3, 1), output_hidden_states=True ) + hidden_states = vision_output.encoder_states # Pass image features through the multi-modal projector image_features = self.projector(hidden_states) diff --git a/mlx_vlm/models/deepseek_vl_v2/vision.py b/mlx_vlm/models/deepseek_vl_v2/vision.py index 8ad8b4c77..c8dce86af 100644 --- a/mlx_vlm/models/deepseek_vl_v2/vision.py +++ b/mlx_vlm/models/deepseek_vl_v2/vision.py @@ -9,6 +9,8 @@ import mlx.nn as nn import numpy as np +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -305,7 +307,9 @@ def __call__( if not self.ignore_head: pooler_output = self.attn_pool(pooler_output) - return pooler_output, x, encoder_states + return VisionModelOutput( + pooler_output=pooler_output, encoder_states=encoder_states + ) class VisionModel(nn.Module): diff --git a/mlx_vlm/models/florence2/florence2.py b/mlx_vlm/models/florence2/florence2.py index 60ebe8ce0..91e3f3704 100644 --- a/mlx_vlm/models/florence2/florence2.py +++ b/mlx_vlm/models/florence2/florence2.py @@ -11,6 +11,7 @@ from huggingface_hub import snapshot_download from mlx.utils import tree_map +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -207,7 +208,8 @@ def _encode_image(self, pixel_values, extract_features=True): # Get vision features if extract_features: batch_size, C, H, W = pixel_values.shape - x = self.vision_tower(pixel_values) + vision_output = self.vision_tower(pixel_values) + x = vision_output.hidden_states else: x = pixel_values batch_size = pixel_values.shape[0] diff --git a/mlx_vlm/models/florence2/vision.py b/mlx_vlm/models/florence2/vision.py index e32425356..b1f0a96c7 100644 --- a/mlx_vlm/models/florence2/vision.py +++ b/mlx_vlm/models/florence2/vision.py @@ -7,6 +7,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -557,7 +559,7 @@ def __call__(self, x): for blk in blks: x, input_size = blk(x, input_size) - return x + return VisionModelOutput(hidden_states=x) @staticmethod def sanitize(weights): diff --git a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py index 1f7426891..8dadff60d 100644 --- a/mlx_vlm/models/idefics2/idefics2.py +++ b/mlx_vlm/models/idefics2/idefics2.py @@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig +from ..base import BaseModel from .language import LanguageModel, TextConfig from .vision import VisionConfig, VisionModel @@ -221,10 +222,10 @@ def get_input_embeddings( inputs_embeds = self.language_model.embed_tokens(input_ids) - pooler_output, embeddings, hidden_state = self.vision_model( + vision_output = self.vision_model( pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True ) - image_features = pooler_output.astype(pixel_values.dtype) + image_features = vision_output.pooler_output.astype(pixel_values.dtype) image_features = self.connector(image_features, mask=None) final_inputs_embeds = self._prepare_inputs_for_multimodal( diff --git a/mlx_vlm/models/idefics2/vision.py b/mlx_vlm/models/idefics2/vision.py index 4764a2f2c..283e53562 100644 --- a/mlx_vlm/models/idefics2/vision.py +++ b/mlx_vlm/models/idefics2/vision.py @@ -6,6 +6,8 @@ import mlx.nn as nn import numpy as np +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -247,7 +249,9 @@ def __call__( pooler_output = self.post_layernorm(encoder_outputs[0]) - return pooler_output, x, encoder_outputs[-1] + return VisionModelOutput( + pooler_output=pooler_output, encoder_states=encoder_outputs + ) def sanitize(self, weights): sanitized_weights = {} diff --git a/mlx_vlm/models/idefics3/idefics3.py b/mlx_vlm/models/idefics3/idefics3.py index b83c48cb2..208b0bf59 100644 --- a/mlx_vlm/models/idefics3/idefics3.py +++ b/mlx_vlm/models/idefics3/idefics3.py @@ -103,9 +103,10 @@ def get_input_embeddings( inputs_embeds = self.language_model.embed_tokens(input_ids) - pooler_output, embeddings, hidden_state = self.vision_model( + vision_output = self.vision_model( pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True ) + pooler_output = vision_output.pooler_output image_features = pooler_output.astype(pixel_values.dtype) image_features = self.connector(image_features) diff --git a/mlx_vlm/models/idefics3/vision.py b/mlx_vlm/models/idefics3/vision.py index a87f1ddcb..5bc5da498 100644 --- a/mlx_vlm/models/idefics3/vision.py +++ b/mlx_vlm/models/idefics3/vision.py @@ -6,6 +6,8 @@ import mlx.nn as nn import numpy as np +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -210,7 +212,9 @@ def __call__( x=x, output_hidden_states=output_hidden_states, mask=None ) pooler_output = self.post_layernorm(encoder_outputs[0]) - return pooler_output, x, encoder_outputs[-1] + return VisionModelOutput( + pooler_output=pooler_output, encoder_states=encoder_outputs + ) def sanitize(self, weights): sanitized_weights = {} diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py index 677a44c4b..4b125629c 100644 --- a/mlx_vlm/models/llava/llava.py +++ b/mlx_vlm/models/llava/llava.py @@ -77,9 +77,10 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - *_, hidden_states = self.vision_tower( + vision_output = self.vision_tower( pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True ) + hidden_states = vision_output.encoder_states # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] diff --git a/mlx_vlm/models/llava/vision.py b/mlx_vlm/models/llava/vision.py index c0efe667d..a74d4247b 100644 --- a/mlx_vlm/models/llava/vision.py +++ b/mlx_vlm/models/llava/vision.py @@ -7,6 +7,8 @@ import mlx.nn as nn import numpy as np +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -207,24 +209,28 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, - output_attn: Optional[bool] = None, + output_attentions: Optional[bool] = None, ) -> mx.array: x = self.embeddings(x) if self.config.model_type == "clip_vision_model": x = self.pre_layrnorm(x) encoder_states = (x,) if output_hidden_states else None - all_attns = () if output_attn else None + all_attentions = () if output_attentions else None for l in self.encoder.layers: x, attn = l(x, mask=None) if output_hidden_states: encoder_states = encoder_states + (x,) - if output_attn: - all_attns = all_attns + (attn,) + if output_attentions: + all_attentions = all_attentions + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) - return pooler_output, x, encoder_states + return VisionModelOutput( + pooler_output=pooler_output, + encoder_states=encoder_states, + attentions=all_attentions, + ) class VisionModel(nn.Module): diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index b90db8c25..b12054517 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -145,11 +145,11 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) - *_, hidden_state = self.vision_tower( + vision_output = self.vision_tower( pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True ) - image_features = hidden_state[-1].astype(pixel_values.dtype) + image_features = vision_output.encoder_states[-1].astype(pixel_values.dtype) assert image_features.shape[-2] == 729 image_features = self.mm_projector(image_features) diff --git a/mlx_vlm/models/llava_bunny/vision.py b/mlx_vlm/models/llava_bunny/vision.py index 294b1c88b..5ee277212 100644 --- a/mlx_vlm/models/llava_bunny/vision.py +++ b/mlx_vlm/models/llava_bunny/vision.py @@ -6,6 +6,8 @@ import mlx.nn as nn import numpy as np +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -225,23 +227,27 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, - output_attn: Optional[bool] = None, + output_attentions: Optional[bool] = None, ) -> mx.array: x = self.embeddings(x) encoder_states = (x,) if output_hidden_states else None - all_attns = () if output_attn else None + all_attentions = () if output_attentions else None for l in self.encoder.layers: x, attn = l(x, mask=None) if output_hidden_states: encoder_states = encoder_states + (x,) - if output_attn: - all_attns = all_attns + (attn,) + if output_attentions: + all_attentions = all_attentions + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) pooler_output = self.head(pooler_output) - return pooler_output, x, encoder_states + return VisionModelOutput( + pooler_output=pooler_output, + encoder_states=encoder_states, + attentions=all_attentions, + ) class SigLipMultiheadAttentionPoolingHead(nn.Module): diff --git a/mlx_vlm/models/llava_next/llava_next.py b/mlx_vlm/models/llava_next/llava_next.py index 1505e7473..9eeeb5995 100644 --- a/mlx_vlm/models/llava_next/llava_next.py +++ b/mlx_vlm/models/llava_next/llava_next.py @@ -83,11 +83,13 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - *_, hidden_states, all_attns = self.vision_tower( + vision_output = self.vision_tower( pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True, - output_attn=True, + output_attentions=True, ) + hidden_states = vision_output.encoder_states + all_attns = vision_output.attentions # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] diff --git a/mlx_vlm/models/llava_next/vision.py b/mlx_vlm/models/llava_next/vision.py index 3443c146f..fd6e70c90 100644 --- a/mlx_vlm/models/llava_next/vision.py +++ b/mlx_vlm/models/llava_next/vision.py @@ -7,6 +7,8 @@ import mlx.nn as nn import numpy as np +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -192,23 +194,27 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, - output_attn: bool = False, + output_attentions: bool = False, ) -> mx.array: x = self.embeddings(x) x = self.pre_layrnorm(x) encoder_states = (x,) if output_hidden_states else None - all_attns = () if output_attn else None + all_attentions = () if output_attentions else None for l in self.encoder.layers: x, attn = l(x, mask=None) if output_hidden_states: encoder_states = encoder_states + (x,) - if output_attn: - all_attns = all_attns + (attn,) + if output_attentions: + all_attentions = all_attentions + (attn,) pooler_output = self.post_layernorm(x[:, 0, :]) - return pooler_output, x, encoder_states, all_attns + return VisionModelOutput( + pooler_output=pooler_output, + encoder_states=encoder_states, + attentions=all_attentions, + ) class VisionModel(nn.Module): @@ -225,9 +231,9 @@ def __call__( self, x: mx.array, output_hidden_states: Optional[bool] = None, - output_attn: bool = False, + output_attentions: bool = False, ) -> mx.array: - return self.vision_model(x, output_hidden_states, output_attn) + return self.vision_model(x, output_hidden_states, output_attentions) def sanitize(self, weights): sanitized_weights = {} diff --git a/mlx_vlm/models/mllama/mllama.py b/mlx_vlm/models/mllama/mllama.py index b51e18563..ec11b369b 100644 --- a/mlx_vlm/models/mllama/mllama.py +++ b/mlx_vlm/models/mllama/mllama.py @@ -70,12 +70,12 @@ def __call__( "`aspect_ratio_ids` must be provided if `pixel_values` is provided" ) - vision_outputs = self.vision_tower( + vision_output = self.vision_tower( pixel_values=pixel_values, aspect_ratio_ids=aspect_ratio_ids, aspect_ratio_mask=aspect_ratio_mask, ) - cross_attention_states = vision_outputs[0] + cross_attention_states = vision_output.hidden_states[0] cross_attention_states = self.multi_modal_projector( cross_attention_states diff --git a/mlx_vlm/models/mllama/vision.py b/mlx_vlm/models/mllama/vision.py index 4fa25829e..2d07aa782 100644 --- a/mlx_vlm/models/mllama/vision.py +++ b/mlx_vlm/models/mllama/vision.py @@ -5,6 +5,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -440,7 +442,7 @@ def __call__( [hidden_state, intermediate_hidden_states], axis=-1 ) - return hidden_state + return VisionModelOutput(hidden_states=hidden_state) @staticmethod def sanitize(weights): diff --git a/mlx_vlm/models/molmo/molmo.py b/mlx_vlm/models/molmo/molmo.py index fd2ec1fa0..c14407c58 100644 --- a/mlx_vlm/models/molmo/molmo.py +++ b/mlx_vlm/models/molmo/molmo.py @@ -80,9 +80,11 @@ def __call__( else None ) - image_features, cls_embed, all_attns = self.vision_tower( + vision_output = self.vision_tower( pixel_values, image_masks, output_attentions=True ) + image_features = vision_output.hidden_states + all_attns = vision_output.attentions # Insert image features into the input embeddings num_image, num_patch = image_features.shape[1:3] @@ -101,9 +103,7 @@ def __call__( image_input_idx = image_input_idx.reshape(batch_size, num_image * num_patch) all_attns = all_attns.reshape(batch_size, num_image * num_patch, -1) - print("all_attns", all_attns.shape) - print("image_features", image_features.shape) - print("image_input_idx", image_input_idx.shape) + if all_attns is not None: attn = all_attns[None, ...] vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) diff --git a/mlx_vlm/models/molmo/vision.py b/mlx_vlm/models/molmo/vision.py index d84628dca..022db433f 100644 --- a/mlx_vlm/models/molmo/vision.py +++ b/mlx_vlm/models/molmo/vision.py @@ -5,6 +5,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -483,7 +485,4 @@ def __call__( # MLP layer to map the feature image_features = self.image_projector(image_features) - if output_attentions: - return image_features, cls_embed, all_attns - else: - return image_features, cls_embed + return VisionModelOutput(hidden_states=image_features, attentions=all_attns) diff --git a/mlx_vlm/models/multi_modality/vision.py b/mlx_vlm/models/multi_modality/vision.py index ae6d4f104..f39694dea 100644 --- a/mlx_vlm/models/multi_modality/vision.py +++ b/mlx_vlm/models/multi_modality/vision.py @@ -10,6 +10,7 @@ import numpy as np from scipy.ndimage import zoom +from ..base import VisionModelOutput from .sam import SAMEncoder @@ -325,7 +326,9 @@ def __call__( if not self.ignore_head: pooler_output = self.attn_pool(pooler_output) - return pooler_output, x, encoder_states + return VisionModelOutput( + pooler_output=pooler_output, hidden_states=x, encoder_states=encoder_states + ) class HybridVisionModel(nn.Module): @@ -346,7 +349,7 @@ def __call__(self, x: mx.array) -> mx.array: if self.resolution == "high": return self.vision_tower(x) else: - return self.vision_tower(x)[0] + return self.vision_tower(x).pooler_output def resize_image(image, size, antialias=True): diff --git a/mlx_vlm/models/paligemma/paligemma.py b/mlx_vlm/models/paligemma/paligemma.py index 8223aa286..36249c27d 100644 --- a/mlx_vlm/models/paligemma/paligemma.py +++ b/mlx_vlm/models/paligemma/paligemma.py @@ -72,12 +72,15 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) - hidden_state, _, _, all_attns = self.vision_tower( + vision_output = self.vision_tower( pixel_values.transpose(0, 2, 3, 1).astype(inputs_embeds.dtype), output_hidden_states=True, output_attentions=True, ) + hidden_state = vision_output.pooler_output + all_attns = vision_output.attentions + image_features = hidden_state[None, :].astype(pixel_values.dtype) if all_attns: diff --git a/mlx_vlm/models/paligemma/vision.py b/mlx_vlm/models/paligemma/vision.py index 55d9af8bc..3817d8759 100644 --- a/mlx_vlm/models/paligemma/vision.py +++ b/mlx_vlm/models/paligemma/vision.py @@ -6,6 +6,8 @@ import mlx.nn as nn import numpy as np +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -203,13 +205,16 @@ def __call__( ) -> mx.array: x = self.embeddings(x) - h, encoder_states, all_attns = self.encoder( + h, encoder_states, all_attentions = self.encoder( x=x, output_hidden_states=output_hidden_states, mask=None ) pooler_output = self.post_layernorm(h) - - return pooler_output, x, encoder_states, all_attns + return VisionModelOutput( + pooler_output=pooler_output, + encoder_states=encoder_states, + attentions=all_attentions, + ) class VisionModel(nn.Module): diff --git a/mlx_vlm/models/phi3_v/vision.py b/mlx_vlm/models/phi3_v/vision.py index e41a013ba..ee9b85912 100644 --- a/mlx_vlm/models/phi3_v/vision.py +++ b/mlx_vlm/models/phi3_v/vision.py @@ -8,6 +8,8 @@ import mlx.nn as nn import numpy as np +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -210,7 +212,9 @@ def __call__( encoder_states = encoder_states + (x,) pooler_output = self.post_layernorm(x[:, 0, :]) - return pooler_output, x, encoder_states + return VisionModelOutput( + pooler_output=pooler_output, hidden_states=x, encoder_states=encoder_states + ) class ClipVModel(nn.Module): @@ -264,7 +268,8 @@ def __call__( img_sizes = (img_sizes // 336).tolist() img_features = self.img_processor.vision_model( img_embeds.reshape(-1, *img_embeds.shape[2:]).transpose(0, 2, 3, 1), True - )[-1][-2][:, 1:] + ) + image_features = img_features.encoder_states[-1][-2][:, 1:] img_features = img_features.reshape(B, -1, *img_features.shape[1:]) C, H = self.image_dim_out, int(img_features.shape[2] ** 0.5) output_imgs, output_len = [], [] diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index 9e1ff7cd9..20e2cce59 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -88,11 +88,15 @@ def get_input_embeddings( # Pass pixel_values as list of images, as each image is individually run through conv2d and position encoding # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21 # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85 - *_, hidden_states, all_attns = self.vision_tower( + vision_output = self.vision_tower( pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True, output_attentions=True, ) + + hidden_states = vision_output.encoder_states + all_attns = vision_output.attentions + # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] diff --git a/mlx_vlm/models/pixtral/vision.py b/mlx_vlm/models/pixtral/vision.py index c54d86ec7..83174567c 100644 --- a/mlx_vlm/models/pixtral/vision.py +++ b/mlx_vlm/models/pixtral/vision.py @@ -5,6 +5,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -285,7 +287,11 @@ def __call__( if output_attentions: all_attns = all_attns + (attn,) - return patch_embeds, encoder_states, all_attns + return VisionModelOutput( + hidden_states=patch_embeds, + encoder_states=encoder_states, + attentions=all_attns, + ) class VisionModel(nn.Module): diff --git a/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py b/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py index 15b61f72a..aa43f1bfc 100644 --- a/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +++ b/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py @@ -72,10 +72,16 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - hidden_states, all_attns = self.vision_tower( - pixel_values, image_grid_thw, output_hidden_states=False, output_attn=True + vision_output = self.vision_tower( + pixel_values, + image_grid_thw, + output_hidden_states=False, + output_attentions=True, ) + hidden_states = vision_output.hidden_states + all_attns = vision_output.attentions + if hidden_states.ndim == 2: hidden_states = hidden_states[None, :, :] diff --git a/mlx_vlm/models/qwen2_5_vl/vision.py b/mlx_vlm/models/qwen2_5_vl/vision.py index 15d08c203..d8fdb4f46 100644 --- a/mlx_vlm/models/qwen2_5_vl/vision.py +++ b/mlx_vlm/models/qwen2_5_vl/vision.py @@ -6,6 +6,8 @@ import mlx.nn as nn import numpy as np +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -363,7 +365,7 @@ def __call__( hidden_states: mx.array, grid_thw: mx.array, output_hidden_states: Optional[bool] = None, - output_attn: Optional[bool] = None, + output_attentions: Optional[bool] = None, ) -> mx.array: hidden_states = self.patch_embed(hidden_states) @@ -409,7 +411,7 @@ def __call__( cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0) encoder_states = (hidden_states,) if output_hidden_states else None - all_attns = () if output_attn else None + all_attentions = () if output_attentions else None for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: @@ -424,13 +426,13 @@ def __call__( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if output_attn: - all_attns = all_attns + (attn,) + if output_attentions: + all_attentions = all_attentions + (attn,) hidden_states = self.merger(hidden_states) reverse_indices = mx.argsort(window_index, axis=0) hidden_states = hidden_states[reverse_indices, :] - return hidden_states, all_attns + return VisionModelOutput(hidden_states=hidden_states, attentions=all_attentions) def sanitize(self, weights): sanitized_weights = {} diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index 7b27049d0..eb1eeb7c7 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -69,15 +69,21 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - hidden_states, all_attns = self.vision_tower( - pixel_values, image_grid_thw, output_hidden_states=False, output_attn=True + vision_output = self.vision_tower( + pixel_values, + image_grid_thw, + output_hidden_states=False, + output_attentions=True, ) + hidden_states = vision_output.hidden_states + all_attentions = vision_output.attentions + if hidden_states.ndim == 2: hidden_states = hidden_states[None, :, :] - if all_attns: - attn = all_attns[-1] + if all_attentions: + attn = all_attentions[-1] vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) hidden_states = self.filter_topk_vision_tokens( diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index b5a26a0b0..77079735f 100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -5,6 +5,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import VisionModelOutput + @dataclass class VisionConfig: @@ -287,7 +289,7 @@ def __call__( hidden_states: mx.array, grid_thw: mx.array, output_hidden_states: Optional[bool] = None, - output_attn: Optional[bool] = None, + output_attentions: Optional[bool] = None, ) -> mx.array: hidden_states = self.patch_embed(hidden_states) @@ -309,7 +311,7 @@ def __call__( cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0) encoder_states = (hidden_states,) if output_hidden_states else None - all_attns = () if output_attn else None + all_attentions = () if output_attentions else None for blk in self.blocks: hidden_states, attn = blk( @@ -317,10 +319,12 @@ def __call__( ) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if output_attn: - all_attns = all_attns + (attn,) + if output_attentions: + all_attentions = all_attentions + (attn,) - return self.merger(hidden_states), all_attns + return VisionModelOutput( + hidden_states=self.merger(hidden_states), attentions=all_attentions + ) def sanitize(self, weights): sanitized_weights = {} diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index 21ebebe54..1e4a6d1af 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -5,6 +5,15 @@ from mlx.utils import tree_map +def get_hidden_states(vision_output): + if vision_output.hidden_states is not None: + return vision_output.hidden_states + elif vision_output.encoder_states is not None: + return vision_output.encoder_states[-1] + else: + return vision_output.pooler_output + + class TestModels(unittest.TestCase): def language_test_runner(self, model, model_type, vocab_size, num_layers): @@ -61,6 +70,7 @@ def vision_test_runner( self.assertEqual(vision_tower.model_type, model_type) batch_size = 1 + all_attentions = None if model_type == "qwen2_5_vl": input_tensor = mx.random.uniform(shape=(image_size[0], image_size[1])) else: @@ -80,17 +90,35 @@ def vision_test_runner( "output_hidden_states" in inspect.signature(vision_tower.__call__).parameters ): - hidden_states = vision_tower( + vision_output = vision_tower( input_tensor, output_hidden_states=True, **kwargs ) + + hidden_states = get_hidden_states(vision_output) + elif "output_attentions" in inspect.signature(vision_tower.__call__).parameters: + vision_output = vision_tower(input_tensor, output_attentions=True, **kwargs) + hidden_states = get_hidden_states(vision_output) + all_attentions = vision_output.attentions + else: - hidden_states = vision_tower(input_tensor, **kwargs) + vision_output = vision_tower(input_tensor, **kwargs) + hidden_states = get_hidden_states(vision_output) + + print("hidden_states", len(hidden_states)) # Check vision hidden feature layer's shape matches the expected hidden size self.assertEqual( hidden_states[vision_feature_layer].shape[-1], vision_hidden_size ) + if "output_attentions" in inspect.signature(vision_tower.__call__).parameters: + if all_attentions is not None: + config = vision_tower.config.__dict__ + if len(all_attentions) > 1: # TODO: Fix this test for Molmo + self.assertEqual( + len(all_attentions), config.get("num_hidden_layers", None) + ) + def test_llava_bunny(self): from mlx_vlm.models import llava_bunny diff --git a/mlx_vlm/tests/test_smoke.py b/mlx_vlm/tests/test_smoke.py index 70877a075..729981a3e 100644 --- a/mlx_vlm/tests/test_smoke.py +++ b/mlx_vlm/tests/test_smoke.py @@ -52,6 +52,12 @@ def parse_args(): "--max-tokens", type=int, default=100, help="Maximum tokens to generate" ) parser.add_argument("--resize-shape", type=int, default=None, help="Resize shape") + parser.add_argument( + "--vision-filter-ratio", type=float, default=0.8, help="Vision filter ratio" + ) + parser.add_argument( + "--vision-merge-ratio", type=float, default=0.8, help="Vision merge ratio" + ) return parser.parse_args() From 7de2f1d25a12ce59ce2c3a0db11a4ebe9a23080b Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 3 Mar 2025 14:34:41 +0100 Subject: [PATCH 17/19] fix video generate --- mlx_vlm/models/qwen2_vl/qwen2_vl.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index eb1eeb7c7..0441435dd 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -86,10 +86,10 @@ def get_input_embeddings( attn = all_attentions[-1] vision_filter_ratio = kwargs.get("vision_filter_ratio", 1.0) vision_merge_ratio = kwargs.get("vision_merge_ratio", 1.0) - hidden_states = self.filter_topk_vision_tokens( + hidden_states, _ = self.filter_topk_vision_tokens( hidden_states, attn, vision_filter_ratio ) - hidden_states = self.merge_similar_vision_tokens( + hidden_states, _ = self.merge_similar_vision_tokens( hidden_states, vision_merge_ratio ) @@ -129,11 +129,14 @@ def __call__( ): image_grid_thw = kwargs.pop("image_grid_thw", None) - if image_grid_thw is not None: - image_grid_thw = mx.array(image_grid_thw) + video_grid_thw = kwargs.pop("video_grid_thw", None) + grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw + + if grid_thw is not None: + grid_thw = mx.array(grid_thw) input_embddings = self.get_input_embeddings( - input_ids, pixel_values, image_grid_thw, **kwargs + input_ids, pixel_values, grid_thw, **kwargs ) logits = self.language_model(None, cache=cache, inputs_embeds=input_embddings) From 16a3d1a07a19479400d6629aa5413a6bd7c1ddfe Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 3 Mar 2025 14:35:12 +0100 Subject: [PATCH 18/19] add filtering and merging to video generate --- mlx_vlm/video_generate.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mlx_vlm/video_generate.py b/mlx_vlm/video_generate.py index 98dc200bc..ee47e74c6 100644 --- a/mlx_vlm/video_generate.py +++ b/mlx_vlm/video_generate.py @@ -454,7 +454,20 @@ def main(): help="Select the model to use", ) parser.add_argument("--verbose", action="store_false", help="Print verbose output") - + parser.add_argument( + "--vision-merge-ratio", + type=float, + default=1.0, + help="Ratio of vision tokens to keep during merging similar tokens (between 0.1 and 1.0).", + choices=[x / 10 for x in range(1, 11)], + ) + parser.add_argument( + "--vision-filter-ratio", + type=float, + default=1.0, + help="Ratio of vision tokens to keep during filtering topk tokens (between 0.1 and 1.0).", + choices=[x / 10 for x in range(1, 11)], + ) args = parser.parse_args() print(f"\033[32mLoading model:\033[0m {args.model}") @@ -529,7 +542,6 @@ def main(): kwargs["video_grid_thw"] = mx.array(inputs["video_grid_thw"]) if inputs.get("image_grid_thw", None) is not None: kwargs["image_grid_thw"] = mx.array(inputs["image_grid_thw"]) - else: if is_video_file(args.video): if len(args.video) > 1: @@ -585,6 +597,8 @@ def main(): kwargs["mask"] = mask kwargs["temperature"] = args.temperature kwargs["max_tokens"] = args.max_tokens + kwargs["vision_merge_ratio"] = args.vision_merge_ratio + kwargs["vision_filter_ratio"] = args.vision_filter_ratio response = generate( model, From ad4edbfd2ed7029480dee2851f8e15a26973e19c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 4 Mar 2025 21:11:59 +0100 Subject: [PATCH 19/19] remove filtering and merging --- mlx_vlm/video_generate.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/mlx_vlm/video_generate.py b/mlx_vlm/video_generate.py index ee47e74c6..aa9be5cc6 100644 --- a/mlx_vlm/video_generate.py +++ b/mlx_vlm/video_generate.py @@ -454,20 +454,7 @@ def main(): help="Select the model to use", ) parser.add_argument("--verbose", action="store_false", help="Print verbose output") - parser.add_argument( - "--vision-merge-ratio", - type=float, - default=1.0, - help="Ratio of vision tokens to keep during merging similar tokens (between 0.1 and 1.0).", - choices=[x / 10 for x in range(1, 11)], - ) - parser.add_argument( - "--vision-filter-ratio", - type=float, - default=1.0, - help="Ratio of vision tokens to keep during filtering topk tokens (between 0.1 and 1.0).", - choices=[x / 10 for x in range(1, 11)], - ) + args = parser.parse_args() print(f"\033[32mLoading model:\033[0m {args.model}") @@ -597,8 +584,6 @@ def main(): kwargs["mask"] = mask kwargs["temperature"] = args.temperature kwargs["max_tokens"] = args.max_tokens - kwargs["vision_merge_ratio"] = args.vision_merge_ratio - kwargs["vision_filter_ratio"] = args.vision_filter_ratio response = generate( model,