Skip to content

Commit 48817fb

Browse files
DOGEUNNKIMclaudeyaoyu-33
authored
feat(model): add Gemma-4 E4B support (layer spec, checkpoint loader, parity check) (#4148)
Signed-off-by: kdg6245 <kdg6245@snu.ac.kr> Signed-off-by: Dogeun Kim <82812668+DOGEUNNKIM@users.noreply.github.com> Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: yaoyu-33 <yaoyu.094@gmail.com>
1 parent eeb2c60 commit 48817fb

22 files changed

Lines changed: 7475 additions & 2006 deletions
Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
# Gemma 4 E4B Examples
2+
3+
This directory contains example scripts for the Gemma 4 E4B dense model.
4+
5+
Gemma 4 E4B is a dense Gemma 4 variant with text, vision, and audio support in
6+
the Hugging Face checkpoint. The Bridge implementation keeps the text-only path
7+
and the vision/audio path separated:
8+
9+
- `Gemma4ForCausalLM` is handled by `Gemma4Bridge` in
10+
`megatron.bridge.models.gemma`.
11+
- `Gemma4ForConditionalGeneration` is handled by `Gemma4VLBridge` in
12+
`megatron.bridge.models.gemma_vl`.
13+
- Shared language-model modules live under `megatron.bridge.models.gemma`; VL
14+
modules extend that implementation without introducing a reverse dependency.
15+
16+
## Requirements
17+
18+
Gemma 4 requires a Megatron-Core checkout on `PYTHONPATH`. The Bridge Gemma 4
19+
provider is designed to work with a clean Megatron-Core checkout: Gemma 4
20+
specific features such as dual RoPE, per-layer embeddings, shared KV, and
21+
embedding scaling are implemented or patched on the Bridge side rather than as
22+
Gemma 4 specific Megatron-Core arguments or `TransformerConfig` fields.
23+
24+
Set `MEGATRON_LM_ROOT` to your Megatron-LM repository:
25+
26+
```bash
27+
export MEGATRON_LM_ROOT=/path/to/Megatron-LM
28+
export PYTHONPATH=$PWD/src:${MEGATRON_LM_ROOT}:${PYTHONPATH:-}
29+
```
30+
31+
Gemma 4 checkpoints may require a recent `transformers` version:
32+
33+
```bash
34+
uv pip install -q --upgrade 'transformers>=5.5.0'
35+
```
36+
37+
The conversion and inference scripts use `uv run --no-sync` where they depend on
38+
the current Python environment package versions. Distributed launch examples use
39+
`uv run python -m torch.distributed.run`, following the repository convention.
40+
41+
## Workspace Configuration
42+
43+
The examples below use a `WORKSPACE` environment variable to keep checkpoints,
44+
logs, and results in one place:
45+
46+
```bash
47+
export WORKSPACE=/your/custom/path
48+
```
49+
50+
Suggested directory structure:
51+
- `${WORKSPACE}/models/` - Converted Megatron checkpoints
52+
- `${WORKSPACE}/results/` - Training outputs and experiment results
53+
- `${WORKSPACE}/logs/` - Parity and training logs
54+
55+
`slurm_pretrain.sh` also requires `GEMMA4_LOG_ROOT` for parity and training
56+
logs:
57+
58+
```bash
59+
export GEMMA4_LOG_ROOT=${WORKSPACE}/logs
60+
```
61+
62+
## Checkpoint Conversion
63+
64+
Gemma 4 E4B has two useful conversion modes:
65+
66+
- `GEMMA4_CONVERSION_MODE=text` imports the text-only GPTModel path, used for
67+
text pretraining and text generation.
68+
- `GEMMA4_CONVERSION_MODE=audio` imports the full VL/audio model path, used for
69+
multimodal parity checks.
70+
71+
### Import HF → Megatron (text)
72+
73+
```bash
74+
GEMMA4_CONVERSION_MODE=text \
75+
uv run --no-sync python examples/conversion/convert_checkpoints.py import \
76+
--hf-model google/gemma-4-E4B-it \
77+
--megatron-path ${WORKSPACE}/models/gemma-4-E4B-it
78+
```
79+
80+
### Import HF → Megatron (VL/audio)
81+
82+
```bash
83+
GEMMA4_CONVERSION_MODE=audio \
84+
uv run --no-sync python examples/conversion/convert_checkpoints.py import \
85+
--hf-model google/gemma-4-E4B-it \
86+
--megatron-path ${WORKSPACE}/models/gemma-4-E4B-it-vl
87+
```
88+
89+
### Export Megatron → HF
90+
91+
```bash
92+
uv run --no-sync python examples/conversion/convert_checkpoints.py export \
93+
--hf-model google/gemma-4-E4B-it \
94+
--megatron-path ${WORKSPACE}/models/gemma-4-E4B-it/iter_0000000 \
95+
--hf-path ${WORKSPACE}/models/gemma-4-E4B-it-hf-export
96+
```
97+
98+
### Round-trip validation
99+
100+
```bash
101+
GEMMA4_CONVERSION_MODE=text \
102+
uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \
103+
examples/conversion/hf_megatron_roundtrip_multi_gpu.py \
104+
--hf-model-id google/gemma-4-E4B-it \
105+
--output-dir ${WORKSPACE}/results/gemma-4-E4B-it-roundtrip \
106+
--tp 2 --pp 1
107+
```
108+
109+
See [conversion.sh](conversion.sh) for the full text-only import, export, and
110+
round-trip workflow.
111+
112+
## Inference
113+
114+
Text-only inference uses `hf_to_megatron_generate_text.py` with
115+
`GEMMA4_CONVERSION_MODE=text` so the bridge selects `Gemma4Bridge` and builds a
116+
`GPTModel`, not the full `Gemma4VLModel`.
117+
118+
### Text generation from HF weights
119+
120+
```bash
121+
GEMMA4_CONVERSION_MODE=text \
122+
uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \
123+
examples/conversion/hf_to_megatron_generate_text.py \
124+
--hf_model_path google/gemma-4-E4B-it \
125+
--prompt $'<start_of_turn>user\nWhat is the capital of France?<end_of_turn>\n<start_of_turn>model\n' \
126+
--max_new_tokens 20 \
127+
--tp 2 --pp 1
128+
```
129+
130+
### Text generation from imported Megatron checkpoint
131+
132+
```bash
133+
GEMMA4_CONVERSION_MODE=text \
134+
uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \
135+
examples/conversion/hf_to_megatron_generate_text.py \
136+
--hf_model_path google/gemma-4-E4B-it \
137+
--megatron_model_path ${WORKSPACE}/models/gemma-4-E4B-it/iter_0000000 \
138+
--prompt $'<start_of_turn>user\nExplain entropy in one sentence.<end_of_turn>\n<start_of_turn>model\n' \
139+
--max_new_tokens 50 \
140+
--tp 2 --pp 1
141+
```
142+
143+
See [inference.sh](inference.sh) for both examples.
144+
145+
> **Note:** `google/gemma-4-E4B-it` is instruction tuned. For high-quality
146+
> assistant-style responses, use prompts and tokenization compatible with the
147+
> model's chat template. The simple generation script is intended as a Bridge
148+
> smoke test, not a production serving path.
149+
150+
## Parity Checks
151+
152+
[parity_check_e4b.py](parity_check_e4b.py) compares Megatron logits against the
153+
Hugging Face model in three modes:
154+
155+
| Mode | Megatron model | HF model | Checkpoint |
156+
|------|---------------|----------|------------|
157+
| `text` | `Gemma4DenseProvider``GPTModel` | `Gemma4ForCausalLM` | text checkpoint |
158+
| `vl` | `Gemma4DenseVLProvider``Gemma4VLModel` | `Gemma4ForConditionalGeneration` | VL/audio checkpoint |
159+
| `audio` | `Gemma4DenseVLProvider``Gemma4VLModel` | `Gemma4ForConditionalGeneration` | VL/audio checkpoint |
160+
161+
### Text parity
162+
163+
```bash
164+
CUDA_DEVICE_MAX_CONNECTIONS=1 uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \
165+
examples/models/gemma/gemma4/parity_check_e4b.py \
166+
--hf-dir /path/to/gemma-4-E4B-it \
167+
--megatron-ckpt ${WORKSPACE}/models/gemma-4-E4B-it \
168+
--tp 2 --bf16 --mode text --atol 3.0
169+
```
170+
171+
### Audio parity
172+
173+
```bash
174+
CUDA_DEVICE_MAX_CONNECTIONS=1 uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \
175+
examples/models/gemma/gemma4/parity_check_e4b.py \
176+
--hf-dir /path/to/gemma-4-E4B-it \
177+
--megatron-ckpt ${WORKSPACE}/models/gemma-4-E4B-it-vl \
178+
--tp 2 --bf16 --mode audio --atol 3.0
179+
```
180+
181+
### Vision parity
182+
183+
```bash
184+
CUDA_DEVICE_MAX_CONNECTIONS=1 uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \
185+
examples/models/gemma/gemma4/parity_check_e4b.py \
186+
--hf-dir /path/to/gemma-4-E4B-it \
187+
--megatron-ckpt ${WORKSPACE}/models/gemma-4-E4B-it-vl \
188+
--tp 2 --bf16 --mode vl --atol 6.0
189+
```
190+
191+
Expected bf16 results:
192+
193+
| Mode | Typical max \|diff\| | atol | Notes |
194+
|------|----------------------|------|-------|
195+
| text | ~2.94 | 3.0 | Softcap 30.0 applied before comparison |
196+
| audio | ~1.65 | 3.0 | 12 audio tokens |
197+
| vl | ~5.47 | 6.0 | 280 image tokens |
198+
199+
The higher VL tolerance is expected. The image path injects many more modality
200+
tokens than the audio path, and bf16 vision feature differences accumulate
201+
through the language model. The worst positions are usually at the image/text
202+
boundary.
203+
204+
## Pretraining
205+
206+
[slurm_pretrain.sh](slurm_pretrain.sh) runs the full workflow:
207+
208+
1. Convert the text checkpoint.
209+
2. Convert the VL/audio checkpoint.
210+
3. Run text, audio, and VL parity checks.
211+
4. Launch Gemma 4 E4B text pretraining.
212+
213+
```bash
214+
HF_MODEL_DIR=/path/to/gemma-4-E4B-it \
215+
MEGATRON_CKPT=${WORKSPACE}/models/gemma4-e4b-megatron \
216+
GEMMA4_LOG_ROOT=${WORKSPACE}/logs \
217+
TRAIN_DATA_PATH=/path/to/data \
218+
bash examples/models/gemma/gemma4/slurm_pretrain.sh
219+
```
220+
221+
The script derives paths automatically:
222+
- `${MEGATRON_CKPT}-text` - text conversion, used for training
223+
- `${MEGATRON_CKPT}-vl` - VL/audio conversion, used for parity checks
224+
225+
Skip flags:
226+
- `SKIP_CONVERT=1`
227+
- `SKIP_TEXT_CONVERT=1`
228+
- `SKIP_VL_CONVERT=1`
229+
- `SKIP_PARITY=1`
230+
231+
## Evaluation
232+
233+
Use the parity checks above as the primary conversion sanity tests. The text
234+
mode verifies the pure LLM path, while the `vl` and `audio` modes verify that
235+
the multimodal wrapper preserves the Hugging Face behavior.
236+
237+
For generation sanity checks, run [inference.sh](inference.sh). For production
238+
serving, export the checkpoint to Hugging Face format and run it with a serving
239+
runtime that supports the Gemma 4 chat template and multimodal preprocessing.
240+
241+
## Running Unit Tests
242+
243+
```bash
244+
PYTHONPATH=$PWD/src:${MEGATRON_LM_ROOT}:${PYTHONPATH:-} uv run --no-sync python -m pytest \
245+
tests/unit_tests/models/gemma/test_gemma4_bridge.py \
246+
tests/unit_tests/models/gemma/test_gemma4_provider.py \
247+
tests/unit_tests/models/gemma_vl/test_gemma4_vl_provider.py \
248+
tests/unit_tests/models/gemma_vl/test_gemma4_vl_bridge.py \
249+
tests/unit_tests/models/gemma_vl/test_gemma4_vl_modeling.py \
250+
tests/unit_tests/recipes/test_gemma4_recipe.py \
251+
-v
252+
```
253+
254+
Multi-GPU unit tests (TP=2, requires 2 GPUs):
255+
256+
```bash
257+
NVIDIA_VISIBLE_DEVICES=0,1 uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \
258+
-m pytest tests/unit_tests/models/gemma_vl -v -k "TensorParallel"
259+
```
260+
261+
## Architecture Notes
262+
263+
### Clean Megatron-Core Compatibility
264+
265+
Gemma 4 keeps model-specific behavior in Bridge:
266+
267+
- `Gemma4DenseProvider` builds a standard `GPTModel`, then installs Gemma 4
268+
dual RoPE, shared-KV wiring, PLE modules, and checkpoint load aliases.
269+
- `modeling_gemma4.py` patches only the created Gemma 4 decoder instance to
270+
thread `per_layer_inputs` through clean Megatron-Core's generic
271+
`extra_block_kwargs` path.
272+
- No Gemma 4 specific Megatron-Core CLI arguments or `TransformerConfig` fields
273+
are required for the dense text path.
274+
275+
### Text and VL Separation
276+
277+
The text-only implementation lives in `megatron.bridge.models.gemma`:
278+
279+
- `modeling_gemma4.py` contains Dense/MoE layers, attention, dual RoPE, PLE,
280+
shared-KV wiring, and output softcapping.
281+
- `gemma4_provider.py` contains `Gemma4DenseProvider` and
282+
`Gemma4ModelProvider`.
283+
- `gemma4_bridge.py` registers `Gemma4ForCausalLM` and defines text checkpoint
284+
mappings.
285+
286+
The VL implementation lives in `megatron.bridge.models.gemma_vl`:
287+
288+
- `modeling_gemma4_vl.py` contains only `Gemma4VLModel` and VL/audio forward
289+
helpers.
290+
- `gemma4_vl_provider.py` contains `Gemma4DenseVLProvider` and
291+
`Gemma4VLModelProvider`.
292+
- `gemma4_vl_bridge.py` registers `Gemma4ForConditionalGeneration` and adds
293+
vision/audio mappings on top of the text mappings.
294+
295+
`gemma_vl` imports from `gemma`; `gemma` does not import from `gemma_vl`.
296+
297+
### Dense E4B Language Model
298+
299+
| Component | Detail |
300+
|-----------|--------|
301+
| 4-norm structure | `input_layernorm` → attention → `post_self_attn_layernorm` → MLP → `post_mlp_layernorm` |
302+
| GQA + sliding/global mix | Sliding layers use 256-dim heads; global layers use 512-dim heads |
303+
| Dual RoPE | Sliding θ=10 000; global θ=1 000 000 with partial factor 0.25 |
304+
| Shared KV | Last 18 layers reuse KV from the last non-shared layer of the same attention type |
305+
| Per-Layer Embeddings | PLE modules are attached after `GPTModel` construction and threaded through `forward()` |
306+
| Logit softcapping | `final_logit_softcapping=30.0` is applied by the Gemma4 output layer |
307+
308+
### VL and Audio Path
309+
310+
`Gemma4VLModel` wraps the language model with HF vision/audio modules:
311+
312+
- Vision tower and projector weights are mapped under `vision_tower.*` and
313+
`embed_vision.*`.
314+
- Audio tower and projector weights are mapped under `audio_tower.*` and
315+
`embed_audio.*`.
316+
- Multimodal token positions are replaced with pad token IDs before PLE lookup,
317+
matching Hugging Face behavior.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/usr/bin/env bash
2+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Workspace directory for checkpoints and results
17+
WORKSPACE=${WORKSPACE:-/workspace}
18+
19+
# Force text-only bridge (Gemma4ForCausalLM / Gemma4DenseProvider).
20+
# gemma-4-E4B-it is Gemma4ForConditionalGeneration in HF; without this flag
21+
# the VL bridge is selected and vision/audio modules are imported unnecessarily.
22+
export GEMMA4_CONVERSION_MODE=text
23+
24+
# Import HF → Megatron (Dense E4B base model)
25+
uv run --no-sync python examples/conversion/convert_checkpoints.py import \
26+
--hf-model google/gemma-4-E4B-it \
27+
--megatron-path ${WORKSPACE}/models/gemma-4-E4B-it
28+
29+
# Export Megatron → HF
30+
uv run --no-sync python examples/conversion/convert_checkpoints.py export \
31+
--hf-model google/gemma-4-E4B-it \
32+
--megatron-path ${WORKSPACE}/models/gemma-4-E4B-it/iter_0000000 \
33+
--hf-path ${WORKSPACE}/models/gemma-4-E4B-it-hf-export
34+
35+
# Round-trip validation
36+
uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 examples/conversion/hf_megatron_roundtrip_multi_gpu.py \
37+
--hf-model-id google/gemma-4-E4B-it \
38+
--output-dir ${WORKSPACE}/results/gemma-4-E4B-it-roundtrip \
39+
--tp 2 --pp 1

0 commit comments

Comments
 (0)