Skip to content

Commit ed8ee2c

Browse files
committed
Small Gemma draft
1 parent 3f9789f commit ed8ee2c

25 files changed

Lines changed: 2254 additions & 27 deletions

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ addopts =
1414
--ignore=tests/unit/dequantize_pack_quantized_int4_test.py
1515
--ignore=tests/unit/gemma3_layers_test.py
1616
--ignore=tests/unit/gemma4_layers_test.py
17+
--ignore=tests/unit/gemma4_small_layers_test.py
1718
--ignore=tests/unit/gpt_vs_reference_test.py
1819
--ignore=tests/unit/llama4_layers_test.py
1920
--ignore=tests/unit/hf_checkpoint_conversion_test.py

src/maxtext/checkpoint_conversion/utils/hf_model_configs.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,137 @@
144144
)
145145

146146

147+
gemma4_e2b_dict = {
148+
"architectures": ["Gemma4ForConditionalGeneration"],
149+
"audio_config": None,
150+
"audio_token_id": 258881,
151+
"boa_token_id": 256000,
152+
"boi_token_id": 255999,
153+
"dtype": "bfloat16",
154+
"eoa_token_id": 258883,
155+
"eoa_token_index": 258883,
156+
"eoi_token_id": 258882,
157+
"eos_token_id": [1, 106],
158+
"image_token_id": 258880,
159+
"initializer_range": 0.02,
160+
"model_type": "gemma4",
161+
"text_config": {
162+
"attention_bias": False,
163+
"attention_dropout": 0.0,
164+
"attention_k_eq_v": False,
165+
"bos_token_id": 2,
166+
"dtype": "bfloat16",
167+
"enable_moe_block": False,
168+
"eos_token_id": 1,
169+
"expert_intermediate_size": None,
170+
"final_logit_softcapping": 30.0,
171+
"global_head_dim": 512,
172+
"head_dim": 256,
173+
"hidden_activation": "gelu_pytorch_tanh",
174+
"hidden_size": 1536,
175+
"hidden_size_per_layer_input": 256,
176+
"initializer_range": 0.02,
177+
"intermediate_size": 6144,
178+
"layer_types": [
179+
"sliding_attention",
180+
"sliding_attention",
181+
"sliding_attention",
182+
"sliding_attention",
183+
"full_attention",
184+
]
185+
* 7,
186+
"max_position_embeddings": 131072,
187+
"model_type": "gemma4_text",
188+
"num_attention_heads": 8,
189+
"num_experts": None,
190+
"num_global_key_value_heads": None,
191+
"num_hidden_layers": 35,
192+
"num_key_value_heads": 1,
193+
"num_kv_shared_layers": 20,
194+
"pad_token_id": 0,
195+
"rms_norm_eps": 1e-06,
196+
"rope_parameters": {
197+
"full_attention": {
198+
"partial_rotary_factor": 0.25,
199+
"rope_theta": 1_000_000.0,
200+
"rope_type": "proportional",
201+
},
202+
"sliding_attention": {"rope_theta": 10_000.0, "rope_type": "default"},
203+
},
204+
"sliding_window": 512,
205+
"tie_word_embeddings": True,
206+
"top_k_experts": None,
207+
"use_bidirectional_attention": None,
208+
"use_cache": True,
209+
"use_double_wide_mlp": True,
210+
"vocab_size": 262144,
211+
"vocab_size_per_layer_input": 262144,
212+
},
213+
"tie_word_embeddings": True,
214+
"transformers_version": "5.5.0.dev0",
215+
"video_token_id": 258884,
216+
"vision_config": {
217+
"attention_bias": False,
218+
"attention_dropout": 0.0,
219+
"default_output_length": 280,
220+
"dtype": "bfloat16",
221+
"global_head_dim": 64,
222+
"head_dim": 64,
223+
"hidden_activation": "gelu_pytorch_tanh",
224+
"hidden_size": 768,
225+
"intermediate_size": 3072,
226+
"max_position_embeddings": 131072,
227+
"model_type": "gemma4_vision",
228+
"num_attention_heads": 12,
229+
"num_hidden_layers": 16,
230+
"num_key_value_heads": 12,
231+
"patch_size": 16,
232+
"pooling_kernel_size": 3,
233+
"position_embedding_size": 10240,
234+
"rms_norm_eps": 1e-06,
235+
"rope_parameters": {"rope_theta": 100.0, "rope_type": "default"},
236+
"standardize": False,
237+
"use_clipped_linears": True,
238+
},
239+
"vision_soft_tokens_per_image": 280,
240+
}
241+
242+
243+
gemma4_e4b_dict = gemma4_e2b_dict.copy()
244+
gemma4_e4b_dict["text_config"] = gemma4_e2b_dict["text_config"].copy()
245+
gemma4_e4b_dict["text_config"].update(
246+
{
247+
"hidden_size": 2560,
248+
"intermediate_size": 10240,
249+
"layer_types": [
250+
"sliding_attention",
251+
"sliding_attention",
252+
"sliding_attention",
253+
"sliding_attention",
254+
"sliding_attention",
255+
"full_attention",
256+
]
257+
* 7,
258+
"num_hidden_layers": 42,
259+
"num_key_value_heads": 2,
260+
"num_kv_shared_layers": 18,
261+
"use_double_wide_mlp": False,
262+
}
263+
)
264+
265+
147266
try:
148267
# Will execute successfully if Transformers is updated with Gemma 4 support
149268
gemma4_26b_config = transformers.Gemma4Config(**gemma4_26b_dict)
150269
gemma4_31b_config = transformers.Gemma4Config(**gemma4_31b_dict)
270+
gemma4_e2b_config = transformers.Gemma4Config(**gemma4_e2b_dict)
271+
gemma4_e4b_config = transformers.Gemma4Config(**gemma4_e4b_dict)
151272
except AttributeError:
152273
# Graceful fallback to raw dict-based PTConfig if Gemma 4 natively is missing
153274
gemma4_26b_config = PTConfig(**gemma4_26b_dict) # pytype: disable=wrong-arg-types
154275
gemma4_31b_config = PTConfig(**gemma4_31b_dict) # pytype: disable=wrong-arg-types
276+
gemma4_e2b_config = PTConfig(**gemma4_e2b_dict) # pytype: disable=wrong-arg-types
277+
gemma4_e4b_config = PTConfig(**gemma4_e4b_dict) # pytype: disable=wrong-arg-types
155278

156279

157280
gemma3_4b_config = transformers.Gemma3Config(
@@ -1185,6 +1308,8 @@ def __init__(self, **kwargs):
11851308
"gemma3-27b": gemma3_27b_config,
11861309
"gemma4-26b": gemma4_26b_config,
11871310
"gemma4-31b": gemma4_31b_config,
1311+
"gemma4-e2b": gemma4_e2b_config,
1312+
"gemma4-e4b": gemma4_e4b_config,
11881313
"qwen2.5-1.5b": qwen25_1_5b_config,
11891314
"qwen2.5-7b": qwen25_7b_config,
11901315
"qwen2.5-14b": qwen25_14b_config,

src/maxtext/checkpoint_conversion/utils/hf_shape.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,92 @@ def GEMMA4_HF_WEIGHTS_TO_SHAPE(config):
284284
return shapes
285285

286286

287+
def GEMMA4_SMALL_HF_WEIGHTS_TO_SHAPE(config):
288+
"""Generates HF parameter shapes for Gemma 4 small (E2B / E4B).
289+
290+
Differs from GEMMA4_HF_WEIGHTS_TO_SHAPE in that it:
291+
* derives global-vs-sliding from the per-model ``layer_types`` list
292+
(E2B has period-5, E4B has period-6),
293+
* emits the Per-Layer-Embedding parameters when ``hidden_size_per_layer_input`` > 0,
294+
* omits k_proj/v_proj/k_norm/v_norm shapes on KV-shared layers, and
295+
* doubles ``intermediate_size`` on shared layers when ``use_double_wide_mlp``
296+
is set (E2B).
297+
"""
298+
shapes = {}
299+
300+
text_cfg = config.get("text_config", config)
301+
vision_cfg = config.get("vision_config", {})
302+
text_base = "model.language_model" if vision_cfg else "model"
303+
304+
hidden_size = text_cfg["hidden_size"]
305+
intermediate_size = text_cfg["intermediate_size"]
306+
num_hidden_layers = text_cfg["num_hidden_layers"]
307+
num_attention_heads = text_cfg["num_attention_heads"]
308+
num_key_value_heads = text_cfg["num_key_value_heads"]
309+
num_global_key_value_heads = text_cfg.get("num_global_key_value_heads") or num_key_value_heads
310+
head_dim = text_cfg["head_dim"]
311+
global_head_dim = text_cfg.get("global_head_dim", head_dim)
312+
vocab_size = text_cfg["vocab_size"]
313+
layer_types = text_cfg.get("layer_types", [])
314+
315+
ple_dim = text_cfg.get("hidden_size_per_layer_input", 0) or 0
316+
vocab_ple = text_cfg.get("vocab_size_per_layer_input", 0) or 0
317+
num_kv_shared = text_cfg.get("num_kv_shared_layers", 0) or 0
318+
first_shared = max(0, num_hidden_layers - num_kv_shared) if num_kv_shared > 0 else num_hidden_layers
319+
use_double_wide_mlp = bool(text_cfg.get("use_double_wide_mlp", False))
320+
321+
shapes[f"{text_base}.embed_tokens.weight"] = [vocab_size, hidden_size]
322+
shapes[f"{text_base}.norm.weight"] = [hidden_size]
323+
324+
if ple_dim > 0:
325+
shapes[f"{text_base}.embed_tokens_per_layer.weight"] = [vocab_ple, num_hidden_layers * ple_dim]
326+
shapes[f"{text_base}.per_layer_model_projection.weight"] = [num_hidden_layers * ple_dim, hidden_size]
327+
shapes[f"{text_base}.per_layer_projection_norm.weight"] = [ple_dim]
328+
329+
for i in range(num_hidden_layers):
330+
hf_prefix = f"{text_base}.layers.{i}"
331+
is_global = i < len(layer_types) and layer_types[i] == "full_attention"
332+
is_shared = num_kv_shared > 0 and i >= first_shared
333+
334+
if is_global:
335+
q_dim = num_attention_heads * global_head_dim
336+
kv_dim = num_global_key_value_heads * global_head_dim
337+
norm_dim = global_head_dim
338+
else:
339+
q_dim = num_attention_heads * head_dim
340+
kv_dim = num_key_value_heads * head_dim
341+
norm_dim = head_dim
342+
343+
shapes[f"{hf_prefix}.self_attn.q_proj.weight"] = [q_dim, hidden_size]
344+
shapes[f"{hf_prefix}.self_attn.o_proj.weight"] = [hidden_size, q_dim]
345+
shapes[f"{hf_prefix}.self_attn.q_norm.weight"] = [norm_dim]
346+
if not is_shared:
347+
shapes[f"{hf_prefix}.self_attn.k_proj.weight"] = [kv_dim, hidden_size]
348+
shapes[f"{hf_prefix}.self_attn.v_proj.weight"] = [kv_dim, hidden_size]
349+
shapes[f"{hf_prefix}.self_attn.k_norm.weight"] = [norm_dim]
350+
# v_norm only when scale is enabled in MaxText; param_mapping suppresses
351+
# this key otherwise, so emit the shape unconditionally — extras are ignored.
352+
shapes[f"{hf_prefix}.self_attn.v_norm.weight"] = [norm_dim]
353+
354+
shapes[f"{hf_prefix}.input_layernorm.weight"] = [hidden_size]
355+
shapes[f"{hf_prefix}.post_attention_layernorm.weight"] = [hidden_size]
356+
shapes[f"{hf_prefix}.pre_feedforward_layernorm.weight"] = [hidden_size]
357+
shapes[f"{hf_prefix}.post_feedforward_layernorm.weight"] = [hidden_size]
358+
shapes[f"{hf_prefix}.layer_scalar"] = [1]
359+
360+
mlp_dim = intermediate_size * 2 if (is_shared and use_double_wide_mlp) else intermediate_size
361+
shapes[f"{hf_prefix}.mlp.gate_proj.weight"] = [mlp_dim, hidden_size]
362+
shapes[f"{hf_prefix}.mlp.up_proj.weight"] = [mlp_dim, hidden_size]
363+
shapes[f"{hf_prefix}.mlp.down_proj.weight"] = [hidden_size, mlp_dim]
364+
365+
if ple_dim > 0:
366+
shapes[f"{hf_prefix}.per_layer_input_gate.weight"] = [ple_dim, hidden_size]
367+
shapes[f"{hf_prefix}.per_layer_projection.weight"] = [hidden_size, ple_dim]
368+
shapes[f"{hf_prefix}.post_per_layer_input_norm.weight"] = [hidden_size]
369+
370+
return shapes
371+
372+
287373
def GEMMA2_HF_WEIGHTS_TO_SHAPE(config):
288374
"""Returns mapping between HuggingFace weights path and weights shape.
289375
@@ -920,6 +1006,8 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config):
9201006
"gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
9211007
"gemma4-26b": GEMMA4_HF_WEIGHTS_TO_SHAPE,
9221008
"gemma4-31b": GEMMA4_HF_WEIGHTS_TO_SHAPE,
1009+
"gemma4-e2b": GEMMA4_SMALL_HF_WEIGHTS_TO_SHAPE,
1010+
"gemma4-e4b": GEMMA4_SMALL_HF_WEIGHTS_TO_SHAPE,
9231011
"qwen2.5-1.5b": QWEN_HF_WEIGHTS_TO_SHAPE,
9241012
"qwen2.5-7b": QWEN_HF_WEIGHTS_TO_SHAPE,
9251013
"qwen2.5-14b": QWEN_HF_WEIGHTS_TO_SHAPE,

0 commit comments

Comments
 (0)