From 7d85739191ce91f5c0ce48189ac3ca9bedf0f4e5 Mon Sep 17 00:00:00 2001 From: spicyneuron <183504714+spicyneuron@users.noreply.github.com> Date: Sun, 5 Apr 2026 21:05:48 +0800 Subject: [PATCH 1/2] Fix Gemma 4 quantized per-layer projection loading --- mlx_vlm/models/gemma4/language.py | 19 ++++-------- mlx_vlm/tests/test_models.py | 49 +++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/mlx_vlm/models/gemma4/language.py b/mlx_vlm/models/gemma4/language.py index ac419b9f1..b86c6de3a 100644 --- a/mlx_vlm/models/gemma4/language.py +++ b/mlx_vlm/models/gemma4/language.py @@ -352,18 +352,6 @@ def __call__( return h -class ScaledLinear(nn.Module): - """Linear layer with output scaling.""" - - def __init__(self, in_features: int, out_features: int, scalar: float): - super().__init__() - self.weight = mx.zeros((out_features, in_features)) - self.scalar = scalar - - def __call__(self, x: mx.array) -> mx.array: - return (x @ self.weight.T) * self.scalar - - class Gemma4TextModel(nn.Module): def __init__(self, config: TextConfig): super().__init__() @@ -418,10 +406,11 @@ def __init__(self, config: TextConfig): ) self.embed_tokens_per_layer_scale = config.hidden_size_per_layer_input**0.5 self.per_layer_input_scale = 2.0**-0.5 - self.per_layer_model_projection = ScaledLinear( + self.per_layer_projection_scale = config.hidden_size**-0.5 + self.per_layer_model_projection = nn.Linear( config.hidden_size, config.num_hidden_layers * config.hidden_size_per_layer_input, - scalar=config.hidden_size**-0.5, + bias=False, ) self.per_layer_projection_norm = RMSNormZeroShift( config.hidden_size_per_layer_input, eps=config.rms_norm_eps @@ -429,6 +418,7 @@ def __init__(self, config: TextConfig): else: self.embed_tokens_per_layer = None self.per_layer_input_scale = None + self.per_layer_projection_scale = None self.per_layer_model_projection = None self.per_layer_projection_norm = None @@ -447,6 +437,7 @@ def project_per_layer_inputs( per_layer_inputs: Optional[mx.array] = None, ) -> mx.array: per_layer_projection = self.per_layer_model_projection(inputs_embeds) + per_layer_projection = per_layer_projection * self.per_layer_projection_scale per_layer_projection = per_layer_projection.reshape( *inputs_embeds.shape[:-1], self.config.num_hidden_layers, diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index 6494a3c17..8895da327 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -1927,6 +1927,11 @@ def test_gemma3(self): def test_gemma4(self): from mlx_vlm.models import gemma4 + from mlx_lm.utils import quantize_model + from mlx_vlm.utils import load_model, save_config, save_weights + + import tempfile + from pathlib import Path text_config = gemma4.TextConfig( model_type="gemma4_text", @@ -1999,6 +2004,50 @@ def test_gemma4(self): output = model(input_ids_with_img, pixel_values=pixel_values) self.assertEqual(output.logits.shape, (1, 6, config.text_config.vocab_size)) + # Quantized save/load regression for per-layer projection. + quant_model = gemma4.Model(config) + + def quantize_per_layer_projection(path: str, _module: nn.Module): + return path == "language_model.model.per_layer_model_projection" + + quant_model, quantized_config = quantize_model( + quant_model, + { + "model_type": "gemma4", + "vocab_size": config.vocab_size, + "image_token_id": config.image_token_id, + "audio_config": None, + "text_config": vars(text_config).copy(), + "vision_config": vars(vision_config).copy(), + }, + group_size=32, + bits=4, + quant_predicate=quantize_per_layer_projection, + ) + self.assertTrue( + hasattr(quant_model.language_model.model.per_layer_model_projection, "scales") + ) + quantized_config["quantization"][ + "language_model.model.per_layer_model_projection" + ] = { + "group_size": 32, + "bits": 4, + "mode": "affine", + } + + with tempfile.TemporaryDirectory() as model_dir: + model_path = Path(model_dir) + save_weights(model_path, quant_model) + save_config(quantized_config, model_path / "config.json") + loaded = load_model(model_path) + + self.assertTrue( + hasattr(loaded.language_model.model.per_layer_model_projection, "scales") + ) + logits = loaded(mx.array([[1, 2, 3]], dtype=mx.int32)).logits + mx.eval(logits) + self.assertEqual(logits.shape, (1, 3, config.vocab_size)) + # Full model forward: text + audio tokens audio_config = gemma4.AudioConfig( hidden_size=32, From 1f9e930db9b931d4870057a927dddf3995712982 Mon Sep 17 00:00:00 2001 From: spicyneuron <183504714+spicyneuron@users.noreply.github.com> Date: Sun, 5 Apr 2026 21:31:25 +0800 Subject: [PATCH 2/2] Format --- mlx_vlm/models/gemma4/language.py | 1 - mlx_vlm/tests/test_models.py | 18 ++++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/mlx_vlm/models/gemma4/language.py b/mlx_vlm/models/gemma4/language.py index b86c6de3a..4dfe15224 100644 --- a/mlx_vlm/models/gemma4/language.py +++ b/mlx_vlm/models/gemma4/language.py @@ -463,7 +463,6 @@ def __call__( h = self.embed_tokens(inputs) h = h * self.embed_scale else: - h = inputs_embeds if self.hidden_size_per_layer_input: diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index 8895da327..8ffb7b2ce 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -7,7 +7,6 @@ class TestModels(unittest.TestCase): - def language_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(model.model_type, model_type) self.assertEqual(len(model.layers), num_layers) @@ -103,7 +102,6 @@ def vision_test_runner( # Check vision hidden feature layer's shape matches the expected hidden size if channel_first: if model_type == "llama4_vision_model": - self.assertEqual(hidden_states.shape[1], vision_hidden_size) else: self.assertEqual(hidden_states.shape[1], vision_hidden_size) @@ -1926,13 +1924,14 @@ def test_gemma3(self): ) def test_gemma4(self): - from mlx_vlm.models import gemma4 - from mlx_lm.utils import quantize_model - from mlx_vlm.utils import load_model, save_config, save_weights - import tempfile from pathlib import Path + from mlx_lm.utils import quantize_model + + from mlx_vlm.models import gemma4 + from mlx_vlm.utils import load_model, save_config, save_weights + text_config = gemma4.TextConfig( model_type="gemma4_text", hidden_size=32, @@ -2025,7 +2024,9 @@ def quantize_per_layer_projection(path: str, _module: nn.Module): quant_predicate=quantize_per_layer_projection, ) self.assertTrue( - hasattr(quant_model.language_model.model.per_layer_model_projection, "scales") + hasattr( + quant_model.language_model.model.per_layer_model_projection, "scales" + ) ) quantized_config["quantization"][ "language_model.model.per_layer_model_projection" @@ -4815,7 +4816,6 @@ def test_glm4v_moe_chunked_prefill_rope(self): class TestMiniCPMO(unittest.TestCase): - @staticmethod def _tiny_config(): from mlx_vlm.models import minicpmo @@ -4922,7 +4922,6 @@ def test_minicpmo_sanitize_audio_conv_layout(self): class TestPhi4MM(unittest.TestCase): - @staticmethod def _tiny_config(): from mlx_vlm.models.phi4mm.config import ModelConfig, TextConfig, VisionConfig @@ -5099,7 +5098,6 @@ def test_phi4mm_set_modality_skips_when_no_lora(self): class TestSam3(unittest.TestCase): - # ─── SAM3 Tests ──────────────────────────────────────────── def test_sam3_config(self):