Finetune Gemma4 family of models with NeMo Automodel #2005
athitten started this conversation in Show and tell
NeMo Automodel added day-0 support for the Gemma4 family of models. Gemma4 is the latest family of open-source models from Google, targeting different classes of hardware, from phones and laptops to data-center GPUs, depending on the use case. The small models google/gemma-4-E2B-it and google/gemma-4-E4B-it are trimodal, supporting audio, image, and text inputs, while the larger ones, google/gemma-4-31B-it and google/gemma-4-26B-A4B-it, support image and text inputs.
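As a quick sanity check before fine-tuning, the smaller trimodal checkpoints can be exercised through the standard Hugging Face image-text-to-text interface. The snippet below is a minimal sketch, assuming the gemma-4 checkpoints expose an `AutoProcessor` and load under `AutoModelForImageTextToText`; the exact model classes, chat-template keys, and the image URL are illustrative, not confirmed against the released checkpoints.

```python
# Minimal sketch: image + text generation with a small Gemma4 checkpoint.
# Assumes the checkpoint follows the standard transformers image-text-to-text
# interface; class names, template keys, and the image URL are illustrative.
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

model_id = "google/gemma-4-E2B-it"  # small trimodal variant from the post

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/chest_xray.png"},
            {"type": "text", "text": "Describe the main finding in this image."},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64)
print(processor.decode(out[0], skip_special_tokens=True))
```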
Key features of Gemma4:
- Effective parameter counts: `gemma-4-E2B-it` and `gemma-4-E4B-it` have effective parameter counts of only 2B and 4B respectively.
- Context length: the E2B and E4B variants support 128K tokens (`max_position_embeddings=131072`), while 31B and 26B-A4B extend to 256K tokens (`max_position_embeddings=262144`).
- All variants use the split-RoPE scheme introduced in Gemma3: θ = 10,000 (default RoPE) on sliding layers and θ = 1,000,000 with a 0.25 proportional partial-rotary factor on global layers.
- `gemma-4-E2B-it`: 35 layers, hidden 1536, FFN intermediate 6144 (uses a double-wide MLP variant), 8 Q heads / 1 KV head.
- `gemma-4-E4B-it`: 42 layers, hidden 2560, FFN intermediate 10240, 8 Q heads / 2 KV heads.
- `gemma-4-31B-it`: 60 layers, hidden 5376, FFN intermediate 21504, 32 Q heads with 16 KV heads on sliding layers and 4 KV heads on global layers.
- `gemma-4-26B-A4B-it`: 30 layers, hidden 2816, and a Mixture-of-Experts FFN with 128 experts, top-8 routing, expert intermediate size 704 (dense FFN intermediate 2112); 16 Q heads with 8 KV heads on sliding layers and 2 KV heads on global layers.
- KV sharing: `gemma-4-E2B-it` shares KV projections across 20 of its 35 layers and `gemma-4-E4B-it` across 18 of its 42 layers (`num_kv_shared_layers`), substantially reducing KV-cache memory for on-device inference. The 31B and 26B-A4B variants do not share KV across layers.
- Sliding-attention layers use `head_dim=256`, while global-attention layers use a wider `head_dim=512`, giving global layers more capacity to aggregate long-range information while keeping local layers cheap.
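Several of the fields above (context length, head dimensions, KV sharing, MoE routing) are plain entries in the Hugging Face model config, so they can be checked directly once the checkpoints are available. Here is a minimal sketch, assuming the attribute names match those quoted above; the exact keys, and whether they live on a nested `text_config`, may differ in the released configs.

```python
# Minimal sketch: inspect the architecture fields quoted above from the HF config.
# Attribute names are taken from the post and are assumptions, not confirmed;
# on multimodal checkpoints they may live under config.text_config instead.
from transformers import AutoConfig

for model_id in ("google/gemma-4-E2B-it", "google/gemma-4-31B-it"):
    config = AutoConfig.from_pretrained(model_id)
    text_cfg = getattr(config, "text_config", config)  # unwrap multimodal wrapper if present
    for key in (
        "max_position_embeddings",  # 131072 vs 262144 context length
        "num_hidden_layers",
        "hidden_size",
        "intermediate_size",
        "num_attention_heads",
        "num_key_value_heads",
        "head_dim",
        "num_kv_shared_layers",     # KV sharing on E2B/E4B only
    ):
        print(model_id, key, getattr(text_cfg, key, "n/a"))
```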
Finetuning Recipes:
We provide full fine-tuning as well as PEFT recipes for all Gemma4 variants:
- `gemma-4-E2B-it`: `gemma4_2b.yaml` for fine-tuning and `gemma4_2b_peft.yaml` for PEFT with LoRA.
- `gemma-4-E4B-it`: `gemma4_4b.yaml` for fine-tuning and `gemma4_4b_peft.yaml` for PEFT with LoRA.
- `gemma-4-31B-it`: `gemma4_31b.yaml` for fine-tuning and `gemma4_31b_peft.yaml` for PEFT with LoRA. Both recipes use FSDP2 with activation checkpointing.
- `gemma-4-26B-A4B-it`: `gemma4_26b_a4b_moe.yaml` for fine-tuning and `gemma4_26b_a4b_moe_peft.yaml` for PEFT with LoRA. Both recipes use FSDP2 with expert parallelism (EP=8, 16 experts per GPU).
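For readers less familiar with LoRA, the sketch below shows what the PEFT recipes do conceptually, expressed with the Hugging Face peft library. The Automodel recipes configure this through YAML rather than code, and the rank, alpha, and target modules here are illustrative choices, not the recipe defaults.

```python
# Conceptual sketch of LoRA fine-tuning using the Hugging Face peft library.
# The NeMo Automodel PEFT recipes express the same idea via YAML; the rank,
# alpha, target modules, and auto class below are illustrative assumptions.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("google/gemma-4-E2B-it")

lora_cfg = LoraConfig(
    r=16,                          # low-rank adapter dimension (assumption)
    lora_alpha=32,                 # scaling factor (assumption)
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # only the adapter weights are trainable
```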
Data
We use the MedPix-VQA dataset as an example. MedPix-VQA is a medical visual question-answering dataset built from the MedPix radiology image archive, pairing clinical images with diagnostic Q&A.
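For experimentation outside the recipes, the dataset can be pulled and inspected with the `datasets` library. A minimal sketch; the hub ID below is a placeholder, so substitute whichever MedPix-VQA source the recipes point at, and the field names are assumptions about a typical VQA schema.

```python
# Minimal sketch: peek at a visual question-answering dataset with the
# huggingface `datasets` library. The hub id is a placeholder, not the one
# the recipes use; field names assume a typical image/question/answer schema.
from datasets import load_dataset

ds = load_dataset("your-org/MedPix-VQA", split="train")  # placeholder id

sample = ds[0]
print(sample.keys())            # expect image / question / answer style fields
print(sample.get("question"))
print(sample.get("answer"))
```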
Below are the loss curves obtained when fine-tuning on MedPix-VQA with these recipes:
While these are single-node recipes, to further scale the largest dense model in the family (`gemma-4-31B-it`), we also provide recipes with tensor and pipeline parallelism: `gemma4_31b_tp4.yaml`, `gemma4_31b_tp4_pp2.yaml`, and `gemma4_31b_tp4_pp4.yaml`.
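The suffixes encode the parallel layout: tp4 shards each layer across 4 GPUs with tensor parallelism, and pp2/pp4 additionally split the layer stack into 2 or 4 pipeline stages. Below is a minimal sketch of the device-mesh geometry implied by a tp4_pp2 layout, purely to illustrate the arrangement; the recipes construct their own mesh internally, and the dimension names and ordering here are assumptions.

```python
# Illustrative only: the device-mesh geometry implied by a tp4_pp2 layout on
# 8 GPUs. The Automodel recipes build their own mesh internally; dimension
# names and ordering here are assumptions for illustration.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")  # launched with torchrun across 8 ranks

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "tp"))

pp_mesh = mesh["pp"]  # 2 pipeline stages, each holding a slice of the layers
tp_mesh = mesh["tp"]  # 4-way tensor-parallel group within each stage
print(f"rank {dist.get_rank()}: pp={pp_mesh.get_local_rank()} tp={tp_mesh.get_local_rank()}")
```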
Many thanks to @HuiyingLi, @khazic, @sharonyu-115, and @akoumpa for all the contributions!!