From a910675dec5a4b0a9f3bcd140b8dc1897aa13319 Mon Sep 17 00:00:00 2001
From: Gagik Amirkhanyan
Date: Fri, 27 Mar 2026 15:38:55 -0700
Subject: [PATCH] fix gemma2 checkpoint conversion validation check

---
 src/maxtext/checkpoint_conversion/to_huggingface.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py
index 94d808d162..73b2f66fb5 100644
--- a/src/maxtext/checkpoint_conversion/to_huggingface.py
+++ b/src/maxtext/checkpoint_conversion/to_huggingface.py
@@ -179,6 +179,17 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool):
     if hf_value is None or mt_value is None:
       continue
 
+    # Special handling for Gemma 2 where local and global layers are bundled
+    if max_config.model_name.startswith("gemma2") and hf_attr == "num_hidden_layers":
+      if isinstance(mt_value, int):
+        mt_value = mt_value * 2
+
+    # Handle vocab size padding
+    if hf_attr == "vocab_size" and isinstance(mt_value, int) and isinstance(hf_value, int):
+      # MaxText often pads vocab size to a multiple of 128 or 256 for TPU efficiency
+      if mt_value >= hf_value and (mt_value - hf_value) < 256:
+        mt_value = hf_value
+
     # Compare values (with tolerance for floats)
     is_match = False
     if isinstance(hf_value, float) or isinstance(mt_value, float):