Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pip install mlx-vlm
Generate output from a model using the CLI:

```sh
python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temp 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg
python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temperature 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg
```

### Chat UI with Gradio
Expand Down
19 changes: 16 additions & 3 deletions mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
DEFAULT_IMAGE = []
DEFAULT_PROMPT = "What are these?"
DEFAULT_MAX_TOKENS = 256
DEFAULT_TEMP = 0.5
DEFAULT_TEMPERATURE = 0.5
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0

Expand Down Expand Up @@ -71,12 +71,25 @@ def parse_arguments():
parser.add_argument(
"--temperature",
type=float,
default=DEFAULT_TEMP,
default=DEFAULT_TEMPERATURE,
help="Temperature for sampling.",
)
parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.")
parser.add_argument("--verbose", action="store_false", help="Detailed output.")

parser.add_argument(
"--vision-merge-ratio",
type=float,
default=1.0,
help="Ratio of vision tokens to keep during merging similar tokens (between 0.1 and 1.0).",
choices=[x / 10 for x in range(1, 11)],
)
parser.add_argument(
"--vision-filter-ratio",
type=float,
default=1.0,
help="Ratio of vision tokens to keep during filtering topk tokens (between 0.1 and 1.0).",
choices=[x / 10 for x in range(1, 11)],
)
return parser.parse_args()


Expand Down
128 changes: 125 additions & 3 deletions mlx_vlm/models/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor as ImageProcessor
from transformers.image_processing_utils import get_size_dict
Expand Down Expand Up @@ -97,6 +98,10 @@ def update(self, keys, values):
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values

@property
def state(self):
return self.keys, self.values


class SimpleKVCache:
"""A simple key-value cache for transformer attention layers.
Expand Down Expand Up @@ -148,15 +153,15 @@ def update(self, keys, values):

class RotatingKVCache:

def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
def __init__(self, head_dim, n_kv_heads, max_size, keep=None, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keep = keep
self.keep = keep if keep is not None else step // 2
self.keys = None
self.values = None
self.offset = 0
Expand Down Expand Up @@ -271,3 +276,120 @@ class LanguageModelOutput:
logits: mx.array
cross_attention_states: Optional[List[mx.array]] = None
encoder_outputs: Optional[List[mx.array]] = None


@dataclass
class VisionModelOutput:
hidden_states: Optional[mx.array] = None
encoder_states: Optional[List[mx.array]] = None
attentions: Optional[List[mx.array]] = None
pooler_output: Optional[mx.array] = None


class BaseModel(nn.Module):
def __init__(self):
super().__init__()
self.vision_tower = None
self.language_model = None

def filter_topk_vision_tokens(
self, image_feature, attn, vision_filter_ratio=None
) -> Tuple[mx.array, mx.array]:
batch_size, seq_len = image_feature.shape[:2]
k_tokens = (
int(image_feature.shape[1] * vision_filter_ratio)
if vision_filter_ratio is not None
else None
) # keep 25% of the visual tokens

if k_tokens is None or k_tokens == seq_len:
return image_feature, None

cls_idx = 0 # self.config.image_token_index

attn_rec = mx.sum(attn[:, :, cls_idx + 1 :, cls_idx], axis=1)

topk_idx = mx.argsort(attn_rec, axis=1)[:, -k_tokens:]

# Create CLS token indices array
# Shape: (B, 1)
cls_indices = mx.full((batch_size, 1), cls_idx, dtype=mx.int32)

# Concat with CLS token index
# Add 1 to account for the offset after CLS token
dominant_idx = mx.concatenate([cls_indices, topk_idx + cls_idx + 1], axis=1)

image_feature = mx.take(image_feature, dominant_idx, axis=1)[0]
return image_feature, dominant_idx

def merge_similar_vision_tokens(
self, image_feature, vision_merge_ratio, merge_rate=0.4
) -> Tuple[mx.array, mx.array]:
# Skip CLS token (first token)
tokens = image_feature[:, 1:]
batch_size, num_tokens, hidden_dim = tokens.shape

# Calculate target number of tokens
target_tokens = max(1, int(num_tokens * vision_merge_ratio))

if num_tokens == target_tokens:
return image_feature, None

# Create a mask of the same shape as tokens, initialized to True
mask = mx.ones((batch_size, num_tokens))

while num_tokens > target_tokens:
# Calculate similarities between adjacent tokens
tokens_a = tokens[:, :-1] # all except last
tokens_b = tokens[:, 1:] # all except first

# Calculate cosine similarity
a_norm = mx.sqrt(mx.sum(tokens_a * tokens_a, axis=-1, keepdims=True))
b_norm = mx.sqrt(mx.sum(tokens_b * tokens_b, axis=-1, keepdims=True))
similarities = mx.sum(tokens_a * tokens_b, axis=-1)
similarities = similarities / (a_norm.squeeze(-1) * b_norm.squeeze(-1))

# Sort similarities and get indices of pairs to merge
# We'll merge about 50% of remaining excess tokens in each iteration
num_to_merge = max(1, int((num_tokens - target_tokens) * merge_rate))
merge_indices = mx.argsort(similarities, axis=-1)[:, -num_to_merge:]

# Create a list to track which indices to merge
to_merge = set(merge_indices[0].tolist())

# Merge selected pairs
new_tokens = []
new_mask = []
i = 0
while i < num_tokens:
if i < num_tokens - 1 and i in to_merge:
# Merge this token with the next one
merged = (tokens[:, i : i + 1] + tokens[:, i + 1 : i + 2]) / 2
new_tokens.append(merged)
new_mask.append(mask[:, i : i + 1]) # Keep mask from first token
i += 2
elif i > 0 and (i - 1) in to_merge:
# Skip this token as it was merged in the previous step
i += 1
else:
# Keep this token as is
new_tokens.append(tokens[:, i : i + 1])
new_mask.append(mask[:, i : i + 1])
i += 1

# Update tokens and mask
tokens = mx.concatenate(new_tokens, axis=1)
mask = mx.concatenate(new_mask, axis=1)
num_tokens = tokens.shape[1]

# Add back CLS token
cls_mask = mx.ones((batch_size, 1), dtype=mx.bool_)
return mx.concatenate([image_feature[:, :1], tokens], axis=1), mx.concatenate(
[cls_mask, mask], axis=1
)

def merge_vision_patches(self, image_feature, vision_merge_ratio, merge_rate=0.4):
"""
Merge vision patches based on the vision_merge_ratio and merge_rate.
"""
pass
7 changes: 4 additions & 3 deletions mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import to_numpy_array

from ..base import expand2square
from ..base import BaseModel, expand2square
from .language import LanguageModel, TextConfig
from .processing_deepsek_vl_v2 import DeepseekVLV2Processor
from .vision import VisionConfig, VisionModel
Expand Down Expand Up @@ -194,7 +194,7 @@ def __call__(self, x):
return x


class Model(nn.Module):
class Model(BaseModel):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
Expand Down Expand Up @@ -409,9 +409,10 @@ def get_input_embeddings(
input_embeds = self.language_model.model.embed_tokens(input_ids)

# Get the ouptut hidden states from the vision model
hidden_states, *_ = self.vision(
vision_output = self.vision(
total_tiles.transpose(0, 2, 3, 1), output_hidden_states=True
)
hidden_states = vision_output.encoder_states

# Pass image features through the multi-modal projector
image_features = self.projector(hidden_states)
Expand Down
6 changes: 5 additions & 1 deletion mlx_vlm/models/deepseek_vl_v2/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import mlx.nn as nn
import numpy as np

from ..base import VisionModelOutput


@dataclass
class VisionConfig:
Expand Down Expand Up @@ -305,7 +307,9 @@ def __call__(

if not self.ignore_head:
pooler_output = self.attn_pool(pooler_output)
return pooler_output, x, encoder_states
return VisionModelOutput(
pooler_output=pooler_output, encoder_states=encoder_states
)


class VisionModel(nn.Module):
Expand Down
6 changes: 4 additions & 2 deletions mlx_vlm/models/florence2/florence2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from huggingface_hub import snapshot_download
from mlx.utils import tree_map

from ..base import BaseModel
from .language import LanguageModel, TextConfig
from .vision import VisionConfig, VisionModel

Expand Down Expand Up @@ -156,7 +157,7 @@ def __call__(self, seq_embeds: mx.array) -> mx.array:
return pos_embeds


class Model(nn.Module):
class Model(BaseModel):
"""Florence-2 model for conditional generation."""

def __init__(self, config: ModelConfig):
Expand Down Expand Up @@ -207,7 +208,8 @@ def _encode_image(self, pixel_values, extract_features=True):
# Get vision features
if extract_features:
batch_size, C, H, W = pixel_values.shape
x = self.vision_tower(pixel_values)
vision_output = self.vision_tower(pixel_values)
x = vision_output.hidden_states
else:
x = pixel_values
batch_size = pixel_values.shape[0]
Expand Down
4 changes: 3 additions & 1 deletion mlx_vlm/models/florence2/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import mlx.core as mx
import mlx.nn as nn

from ..base import VisionModelOutput


@dataclass
class VisionConfig:
Expand Down Expand Up @@ -557,7 +559,7 @@ def __call__(self, x):
for blk in blks:
x, input_size = blk(x, input_size)

return x
return VisionModelOutput(hidden_states=x)

@staticmethod
def sanitize(weights):
Expand Down
7 changes: 4 additions & 3 deletions mlx_vlm/models/idefics2/idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from huggingface_hub import snapshot_download
from transformers import AutoConfig

from ..base import BaseModel
from .language import LanguageModel, TextConfig
from .vision import VisionConfig, VisionModel

Expand Down Expand Up @@ -200,7 +201,7 @@ def __call__(self, x: mx.array, mask=None) -> mx.array:
return x


class Model(nn.Module):
class Model(BaseModel):
def __init__(self, config: ModelConfig):
super().__init__()
self.model_type = config.model_type
Expand All @@ -221,10 +222,10 @@ def get_input_embeddings(

inputs_embeds = self.language_model.embed_tokens(input_ids)

pooler_output, embeddings, hidden_state = self.vision_model(
vision_output = self.vision_model(
pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True
)
image_features = pooler_output.astype(pixel_values.dtype)
image_features = vision_output.pooler_output.astype(pixel_values.dtype)
image_features = self.connector(image_features, mask=None)

final_inputs_embeds = self._prepare_inputs_for_multimodal(
Expand Down
6 changes: 5 additions & 1 deletion mlx_vlm/models/idefics2/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import mlx.nn as nn
import numpy as np

from ..base import VisionModelOutput


@dataclass
class VisionConfig:
Expand Down Expand Up @@ -247,7 +249,9 @@ def __call__(

pooler_output = self.post_layernorm(encoder_outputs[0])

return pooler_output, x, encoder_outputs[-1]
return VisionModelOutput(
pooler_output=pooler_output, encoder_states=encoder_outputs
)

def sanitize(self, weights):
sanitized_weights = {}
Expand Down
6 changes: 4 additions & 2 deletions mlx_vlm/models/idefics3/idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from huggingface_hub import snapshot_download
from transformers import AutoConfig

from ..base import BaseModel
from .language import LanguageModel, TextConfig
from .vision import VisionConfig, VisionModel

Expand Down Expand Up @@ -81,7 +82,7 @@ def __call__(self, image_hidden_states):
return image_hidden_states


class Model(nn.Module):
class Model(BaseModel):
def __init__(self, config: ModelConfig):
super().__init__()
self.model_type = config.model_type
Expand All @@ -102,9 +103,10 @@ def get_input_embeddings(

inputs_embeds = self.language_model.embed_tokens(input_ids)

pooler_output, embeddings, hidden_state = self.vision_model(
vision_output = self.vision_model(
pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True
)
pooler_output = vision_output.pooler_output

image_features = pooler_output.astype(pixel_values.dtype)
image_features = self.connector(image_features)
Expand Down
Loading