Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions mlx_vlm/models/gemma4/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,18 +356,6 @@ def __call__(
return h


class ScaledLinear(nn.Module):
    """Bias-free linear projection whose output is multiplied by a constant.

    The layer holds a single ``weight`` array of shape
    ``(out_features, in_features)``, initialised to zeros, and a fixed
    ``scalar`` that is applied to the projected output on every call.
    """

    def __init__(self, in_features: int, out_features: int, scalar: float):
        super().__init__()
        weight_shape = (out_features, in_features)
        self.weight = mx.zeros(weight_shape)
        # Constant multiplier applied after the matrix product.
        self.scalar = scalar

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``(x @ weight.T) * scalar``."""
        projected = x @ self.weight.T
        return projected * self.scalar


class Gemma4TextModel(nn.Module):
def __init__(self, config: TextConfig):
super().__init__()
Expand Down Expand Up @@ -422,17 +410,19 @@ def __init__(self, config: TextConfig):
)
self.embed_tokens_per_layer_scale = config.hidden_size_per_layer_input**0.5
self.per_layer_input_scale = 2.0**-0.5
self.per_layer_model_projection = ScaledLinear(
self.per_layer_projection_scale = config.hidden_size**-0.5
self.per_layer_model_projection = nn.Linear(
config.hidden_size,
config.num_hidden_layers * config.hidden_size_per_layer_input,
scalar=config.hidden_size**-0.5,
bias=False,
)
self.per_layer_projection_norm = RMSNormZeroShift(
config.hidden_size_per_layer_input, eps=config.rms_norm_eps
)
else:
self.embed_tokens_per_layer = None
self.per_layer_input_scale = None
self.per_layer_projection_scale = None
self.per_layer_model_projection = None
self.per_layer_projection_norm = None

Expand All @@ -451,6 +441,7 @@ def project_per_layer_inputs(
per_layer_inputs: Optional[mx.array] = None,
) -> mx.array:
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
Expand All @@ -476,7 +467,6 @@ def __call__(
h = self.embed_tokens(inputs)
h = h * self.embed_scale
else:

h = inputs_embeds

if self.hidden_size_per_layer_input:
Expand Down
57 changes: 52 additions & 5 deletions mlx_vlm/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class TestModels(unittest.TestCase):

def language_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(model.model_type, model_type)
self.assertEqual(len(model.layers), num_layers)
Expand Down Expand Up @@ -103,7 +102,6 @@ def vision_test_runner(
# Check vision hidden feature layer's shape matches the expected hidden size
if channel_first:
if model_type == "llama4_vision_model":

self.assertEqual(hidden_states.shape[1], vision_hidden_size)
else:
self.assertEqual(hidden_states.shape[1], vision_hidden_size)
Expand Down Expand Up @@ -1926,7 +1924,13 @@ def test_gemma3(self):
)

def test_gemma4(self):
import tempfile
from pathlib import Path

from mlx_lm.utils import quantize_model

from mlx_vlm.models import gemma4
from mlx_vlm.utils import load_model, save_config, save_weights

text_config = gemma4.TextConfig(
model_type="gemma4_text",
Expand Down Expand Up @@ -1999,6 +2003,52 @@ def test_gemma4(self):
output = model(input_ids_with_img, pixel_values=pixel_values)
self.assertEqual(output.logits.shape, (1, 6, config.text_config.vocab_size))

# Quantized save/load regression for per-layer projection.
quant_model = gemma4.Model(config)

def quantize_per_layer_projection(path: str, _module: nn.Module):
    """Return True only for the per-layer model projection module path.

    Passed as ``quant_predicate`` so that this one module alone is
    selected; the module object itself is not inspected.
    """
    target_path = "language_model.model.per_layer_model_projection"
    return path == target_path

quant_model, quantized_config = quantize_model(
quant_model,
{
"model_type": "gemma4",
"vocab_size": config.vocab_size,
"image_token_id": config.image_token_id,
"audio_config": None,
"text_config": vars(text_config).copy(),
"vision_config": vars(vision_config).copy(),
},
group_size=32,
bits=4,
quant_predicate=quantize_per_layer_projection,
)
self.assertTrue(
hasattr(
quant_model.language_model.model.per_layer_model_projection, "scales"
)
)
quantized_config["quantization"][
"language_model.model.per_layer_model_projection"
] = {
"group_size": 32,
"bits": 4,
"mode": "affine",
}

with tempfile.TemporaryDirectory() as model_dir:
model_path = Path(model_dir)
save_weights(model_path, quant_model)
save_config(quantized_config, model_path / "config.json")
loaded = load_model(model_path)

self.assertTrue(
hasattr(loaded.language_model.model.per_layer_model_projection, "scales")
)
logits = loaded(mx.array([[1, 2, 3]], dtype=mx.int32)).logits
mx.eval(logits)
self.assertEqual(logits.shape, (1, 3, config.vocab_size))

# Full model forward: text + audio tokens
audio_config = gemma4.AudioConfig(
hidden_size=32,
Expand Down Expand Up @@ -4853,7 +4903,6 @@ def test_glm4v_moe_chunked_prefill_rope(self):


class TestMiniCPMO(unittest.TestCase):

@staticmethod
def _tiny_config():
from mlx_vlm.models import minicpmo
Expand Down Expand Up @@ -4960,7 +5009,6 @@ def test_minicpmo_sanitize_audio_conv_layout(self):


class TestPhi4MM(unittest.TestCase):

@staticmethod
def _tiny_config():
from mlx_vlm.models.phi4mm.config import ModelConfig, TextConfig, VisionConfig
Expand Down Expand Up @@ -5137,7 +5185,6 @@ def test_phi4mm_set_modality_skips_when_no_lora(self):


class TestSam3(unittest.TestCase):

# ─── SAM3 Tests ────────────────────────────────────────────

def test_sam3_config(self):
Expand Down
Loading