|
| 1 | +# Gemma 4 E4B Examples |
| 2 | + |
| 3 | +This directory contains example scripts for the Gemma 4 E4B dense model. |
| 4 | + |
| 5 | +Gemma 4 E4B is a dense Gemma 4 variant with text, vision, and audio support in |
| 6 | +the Hugging Face checkpoint. The Bridge implementation keeps the text-only path |
| 7 | +and the vision/audio path separated: |
| 8 | + |
| 9 | +- `Gemma4ForCausalLM` is handled by `Gemma4Bridge` in |
| 10 | + `megatron.bridge.models.gemma`. |
| 11 | +- `Gemma4ForConditionalGeneration` is handled by `Gemma4VLBridge` in |
| 12 | + `megatron.bridge.models.gemma_vl`. |
| 13 | +- Shared language-model modules live under `megatron.bridge.models.gemma`; VL |
| 14 | + modules extend that implementation without introducing a reverse dependency. |
| 15 | + |
| 16 | +## Requirements |
| 17 | + |
| 18 | +Gemma 4 requires a Megatron-Core checkout on `PYTHONPATH`. The Bridge Gemma 4 |
| 19 | +provider is designed to work with a clean Megatron-Core checkout: Gemma 4 |
| 20 | +specific features such as dual RoPE, per-layer embeddings, shared KV, and |
| 21 | +embedding scaling are implemented or patched on the Bridge side rather than as |
| 22 | +Gemma 4 specific Megatron-Core arguments or `TransformerConfig` fields. |
| 23 | + |
| 24 | +Set `MEGATRON_LM_ROOT` to your Megatron-LM repository: |
| 25 | + |
| 26 | +```bash |
| 27 | +export MEGATRON_LM_ROOT=/path/to/Megatron-LM |
| 28 | +export PYTHONPATH=$PWD/src:${MEGATRON_LM_ROOT}:${PYTHONPATH:-} |
| 29 | +``` |
| 30 | + |
| 31 | +Gemma 4 checkpoints may require a recent `transformers` version: |
| 32 | + |
| 33 | +```bash |
| 34 | +uv pip install -q --upgrade 'transformers>=5.5.0' |
| 35 | +``` |
| 36 | + |
| 37 | +The conversion and inference scripts use `uv run --no-sync` where they depend on |
| 38 | +the current Python environment package versions. Distributed launch examples use |
| 39 | +`uv run python -m torch.distributed.run`, following the repository convention. |
| 40 | + |
| 41 | +## Workspace Configuration |
| 42 | + |
| 43 | +The examples below use a `WORKSPACE` environment variable to keep checkpoints, |
| 44 | +logs, and results in one place: |
| 45 | + |
| 46 | +```bash |
| 47 | +export WORKSPACE=/your/custom/path |
| 48 | +``` |
| 49 | + |
| 50 | +Suggested directory structure: |
| 51 | +- `${WORKSPACE}/models/` - Converted Megatron checkpoints |
| 52 | +- `${WORKSPACE}/results/` - Training outputs and experiment results |
| 53 | +- `${WORKSPACE}/logs/` - Parity and training logs |
| 54 | + |
| 55 | +`slurm_pretrain.sh` also requires `GEMMA4_LOG_ROOT` for parity and training |
| 56 | +logs: |
| 57 | + |
| 58 | +```bash |
| 59 | +export GEMMA4_LOG_ROOT=${WORKSPACE}/logs |
| 60 | +``` |
| 61 | + |
| 62 | +## Checkpoint Conversion |
| 63 | + |
| 64 | +Gemma 4 E4B has two useful conversion modes: |
| 65 | + |
| 66 | +- `GEMMA4_CONVERSION_MODE=text` imports the text-only GPTModel path, used for |
| 67 | + text pretraining and text generation. |
| 68 | +- `GEMMA4_CONVERSION_MODE=audio` imports the full VL/audio model path, used for |
| 69 | + multimodal parity checks. |
| 70 | + |
| 71 | +### Import HF → Megatron (text) |
| 72 | + |
| 73 | +```bash |
| 74 | +GEMMA4_CONVERSION_MODE=text \ |
| 75 | +uv run --no-sync python examples/conversion/convert_checkpoints.py import \ |
| 76 | + --hf-model google/gemma-4-E4B-it \ |
| 77 | + --megatron-path ${WORKSPACE}/models/gemma-4-E4B-it |
| 78 | +``` |
| 79 | + |
| 80 | +### Import HF → Megatron (VL/audio) |
| 81 | + |
| 82 | +```bash |
| 83 | +GEMMA4_CONVERSION_MODE=audio \ |
| 84 | +uv run --no-sync python examples/conversion/convert_checkpoints.py import \ |
| 85 | + --hf-model google/gemma-4-E4B-it \ |
| 86 | + --megatron-path ${WORKSPACE}/models/gemma-4-E4B-it-vl |
| 87 | +``` |
| 88 | + |
| 89 | +### Export Megatron → HF |
| 90 | + |
| 91 | +```bash |
| 92 | +uv run --no-sync python examples/conversion/convert_checkpoints.py export \ |
| 93 | + --hf-model google/gemma-4-E4B-it \ |
| 94 | + --megatron-path ${WORKSPACE}/models/gemma-4-E4B-it/iter_0000000 \ |
| 95 | + --hf-path ${WORKSPACE}/models/gemma-4-E4B-it-hf-export |
| 96 | +``` |
| 97 | + |
| 98 | +### Round-trip validation |
| 99 | + |
| 100 | +```bash |
| 101 | +GEMMA4_CONVERSION_MODE=text \ |
| 102 | +uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \ |
| 103 | + examples/conversion/hf_megatron_roundtrip_multi_gpu.py \ |
| 104 | + --hf-model-id google/gemma-4-E4B-it \ |
| 105 | + --output-dir ${WORKSPACE}/results/gemma-4-E4B-it-roundtrip \ |
| 106 | + --tp 2 --pp 1 |
| 107 | +``` |
| 108 | + |
| 109 | +See [conversion.sh](conversion.sh) for the full text-only import, export, and |
| 110 | +round-trip workflow. |
| 111 | + |
| 112 | +## Inference |
| 113 | + |
| 114 | +Text-only inference uses `hf_to_megatron_generate_text.py` with |
| 115 | +`GEMMA4_CONVERSION_MODE=text` so the bridge selects `Gemma4Bridge` and builds a |
| 116 | +`GPTModel`, not the full `Gemma4VLModel`. |
| 117 | + |
| 118 | +### Text generation from HF weights |
| 119 | + |
| 120 | +```bash |
| 121 | +GEMMA4_CONVERSION_MODE=text \ |
| 122 | +uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \ |
| 123 | + examples/conversion/hf_to_megatron_generate_text.py \ |
| 124 | + --hf_model_path google/gemma-4-E4B-it \ |
| 125 | + --prompt $'<start_of_turn>user\nWhat is the capital of France?<end_of_turn>\n<start_of_turn>model\n' \ |
| 126 | + --max_new_tokens 20 \ |
| 127 | + --tp 2 --pp 1 |
| 128 | +``` |
| 129 | + |
| 130 | +### Text generation from imported Megatron checkpoint |
| 131 | + |
| 132 | +```bash |
| 133 | +GEMMA4_CONVERSION_MODE=text \ |
| 134 | +uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \ |
| 135 | + examples/conversion/hf_to_megatron_generate_text.py \ |
| 136 | + --hf_model_path google/gemma-4-E4B-it \ |
| 137 | + --megatron_model_path ${WORKSPACE}/models/gemma-4-E4B-it/iter_0000000 \ |
| 138 | + --prompt $'<start_of_turn>user\nExplain entropy in one sentence.<end_of_turn>\n<start_of_turn>model\n' \ |
| 139 | + --max_new_tokens 50 \ |
| 140 | + --tp 2 --pp 1 |
| 141 | +``` |
| 142 | + |
| 143 | +See [inference.sh](inference.sh) for both examples. |
| 144 | + |
| 145 | +> **Note:** `google/gemma-4-E4B-it` is instruction tuned. For high-quality |
| 146 | +> assistant-style responses, use prompts and tokenization compatible with the |
| 147 | +> model's chat template. The simple generation script is intended as a Bridge |
| 148 | +> smoke test, not a production serving path. |
| 149 | +
|
| 150 | +## Parity Checks |
| 151 | + |
| 152 | +[parity_check_e4b.py](parity_check_e4b.py) compares Megatron logits against the |
| 153 | +Hugging Face model in three modes: |
| 154 | + |
| 155 | +| Mode | Megatron model | HF model | Checkpoint | |
| 156 | +|------|---------------|----------|------------| |
| 157 | +| `text` | `Gemma4DenseProvider` → `GPTModel` | `Gemma4ForCausalLM` | text checkpoint | |
| 158 | +| `vl` | `Gemma4DenseVLProvider` → `Gemma4VLModel` | `Gemma4ForConditionalGeneration` | VL/audio checkpoint | |
| 159 | +| `audio` | `Gemma4DenseVLProvider` → `Gemma4VLModel` | `Gemma4ForConditionalGeneration` | VL/audio checkpoint | |
| 160 | + |
| 161 | +### Text parity |
| 162 | + |
| 163 | +```bash |
| 164 | +CUDA_DEVICE_MAX_CONNECTIONS=1 uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \ |
| 165 | + examples/models/gemma/gemma4/parity_check_e4b.py \ |
| 166 | + --hf-dir /path/to/gemma-4-E4B-it \ |
| 167 | + --megatron-ckpt ${WORKSPACE}/models/gemma-4-E4B-it \ |
| 168 | + --tp 2 --bf16 --mode text --atol 3.0 |
| 169 | +``` |
| 170 | + |
| 171 | +### Audio parity |
| 172 | + |
| 173 | +```bash |
| 174 | +CUDA_DEVICE_MAX_CONNECTIONS=1 uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \ |
| 175 | + examples/models/gemma/gemma4/parity_check_e4b.py \ |
| 176 | + --hf-dir /path/to/gemma-4-E4B-it \ |
| 177 | + --megatron-ckpt ${WORKSPACE}/models/gemma-4-E4B-it-vl \ |
| 178 | + --tp 2 --bf16 --mode audio --atol 3.0 |
| 179 | +``` |
| 180 | + |
| 181 | +### Vision parity |
| 182 | + |
| 183 | +```bash |
| 184 | +CUDA_DEVICE_MAX_CONNECTIONS=1 uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \ |
| 185 | + examples/models/gemma/gemma4/parity_check_e4b.py \ |
| 186 | + --hf-dir /path/to/gemma-4-E4B-it \ |
| 187 | + --megatron-ckpt ${WORKSPACE}/models/gemma-4-E4B-it-vl \ |
| 188 | + --tp 2 --bf16 --mode vl --atol 6.0 |
| 189 | +``` |
| 190 | + |
| 191 | +Expected bf16 results: |
| 192 | + |
| 193 | +| Mode | Typical max \|diff\| | atol | Notes | |
| 194 | +|------|----------------------|------|-------| |
| 195 | +| text | ~2.94 | 3.0 | Softcap 30.0 applied before comparison | |
| 196 | +| audio | ~1.65 | 3.0 | 12 audio tokens | |
| 197 | +| vl | ~5.47 | 6.0 | 280 image tokens | |
| 198 | + |
| 199 | +The higher VL tolerance is expected. The image path injects many more modality |
| 200 | +tokens than the audio path, and bf16 vision feature differences accumulate |
| 201 | +through the language model. The worst positions are usually at the image/text |
| 202 | +boundary. |
| 203 | + |
| 204 | +## Pretraining |
| 205 | + |
| 206 | +[slurm_pretrain.sh](slurm_pretrain.sh) runs the full workflow: |
| 207 | + |
| 208 | +1. Convert the text checkpoint. |
| 209 | +2. Convert the VL/audio checkpoint. |
| 210 | +3. Run text, audio, and VL parity checks. |
| 211 | +4. Launch Gemma 4 E4B text pretraining. |
| 212 | + |
| 213 | +```bash |
| 214 | +HF_MODEL_DIR=/path/to/gemma-4-E4B-it \ |
| 215 | +MEGATRON_CKPT=${WORKSPACE}/models/gemma4-e4b-megatron \ |
| 216 | +GEMMA4_LOG_ROOT=${WORKSPACE}/logs \ |
| 217 | +TRAIN_DATA_PATH=/path/to/data \ |
| 218 | +bash examples/models/gemma/gemma4/slurm_pretrain.sh |
| 219 | +``` |
| 220 | + |
| 221 | +The script derives paths automatically: |
| 222 | +- `${MEGATRON_CKPT}-text` - text conversion, used for training |
| 223 | +- `${MEGATRON_CKPT}-vl` - VL/audio conversion, used for parity checks |
| 224 | + |
| 225 | +Skip flags: |
| 226 | +- `SKIP_CONVERT=1` |
| 227 | +- `SKIP_TEXT_CONVERT=1` |
| 228 | +- `SKIP_VL_CONVERT=1` |
| 229 | +- `SKIP_PARITY=1` |
| 230 | + |
| 231 | +## Evaluation |
| 232 | + |
| 233 | +Use the parity checks above as the primary conversion sanity tests. The text |
| 234 | +mode verifies the pure LLM path, while the `vl` and `audio` modes verify that |
| 235 | +the multimodal wrapper preserves the Hugging Face behavior. |
| 236 | + |
| 237 | +For generation sanity checks, run [inference.sh](inference.sh). For production |
| 238 | +serving, export the checkpoint to Hugging Face format and run it with a serving |
| 239 | +runtime that supports the Gemma 4 chat template and multimodal preprocessing. |
| 240 | + |
| 241 | +## Running Unit Tests |
| 242 | + |
| 243 | +```bash |
| 244 | +PYTHONPATH=$PWD/src:${MEGATRON_LM_ROOT}:${PYTHONPATH:-} uv run --no-sync python -m pytest \ |
| 245 | + tests/unit_tests/models/gemma/test_gemma4_bridge.py \ |
| 246 | + tests/unit_tests/models/gemma/test_gemma4_provider.py \ |
| 247 | + tests/unit_tests/models/gemma_vl/test_gemma4_vl_provider.py \ |
| 248 | + tests/unit_tests/models/gemma_vl/test_gemma4_vl_bridge.py \ |
| 249 | + tests/unit_tests/models/gemma_vl/test_gemma4_vl_modeling.py \ |
| 250 | + tests/unit_tests/recipes/test_gemma4_recipe.py \ |
| 251 | + -v |
| 252 | +``` |
| 253 | + |
| 254 | +Multi-GPU unit tests (TP=2, requires 2 GPUs): |
| 255 | + |
| 256 | +```bash |
| 257 | +NVIDIA_VISIBLE_DEVICES=0,1 uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 \ |
| 258 | + -m pytest tests/unit_tests/models/gemma_vl -v -k "TensorParallel" |
| 259 | +``` |
| 260 | + |
| 261 | +## Architecture Notes |
| 262 | + |
| 263 | +### Clean Megatron-Core Compatibility |
| 264 | + |
| 265 | +Gemma 4 keeps model-specific behavior in Bridge: |
| 266 | + |
| 267 | +- `Gemma4DenseProvider` builds a standard `GPTModel`, then installs Gemma 4 |
| 268 | + dual RoPE, shared-KV wiring, PLE modules, and checkpoint load aliases. |
| 269 | +- `modeling_gemma4.py` patches only the created Gemma 4 decoder instance to |
| 270 | + thread `per_layer_inputs` through clean Megatron-Core's generic |
| 271 | + `extra_block_kwargs` path. |
| 272 | +- No Gemma 4 specific Megatron-Core CLI arguments or `TransformerConfig` fields |
| 273 | + are required for the dense text path. |
| 274 | + |
| 275 | +### Text and VL Separation |
| 276 | + |
| 277 | +The text-only implementation lives in `megatron.bridge.models.gemma`: |
| 278 | + |
| 279 | +- `modeling_gemma4.py` contains Dense/MoE layers, attention, dual RoPE, PLE, |
| 280 | + shared-KV wiring, and output softcapping. |
| 281 | +- `gemma4_provider.py` contains `Gemma4DenseProvider` and |
| 282 | + `Gemma4ModelProvider`. |
| 283 | +- `gemma4_bridge.py` registers `Gemma4ForCausalLM` and defines text checkpoint |
| 284 | + mappings. |
| 285 | + |
| 286 | +The VL implementation lives in `megatron.bridge.models.gemma_vl`: |
| 287 | + |
| 288 | +- `modeling_gemma4_vl.py` contains only `Gemma4VLModel` and VL/audio forward |
| 289 | + helpers. |
| 290 | +- `gemma4_vl_provider.py` contains `Gemma4DenseVLProvider` and |
| 291 | + `Gemma4VLModelProvider`. |
| 292 | +- `gemma4_vl_bridge.py` registers `Gemma4ForConditionalGeneration` and adds |
| 293 | + vision/audio mappings on top of the text mappings. |
| 294 | + |
| 295 | +`gemma_vl` imports from `gemma`; `gemma` does not import from `gemma_vl`. |
| 296 | + |
| 297 | +### Dense E4B Language Model |
| 298 | + |
| 299 | +| Component | Detail | |
| 300 | +|-----------|--------| |
| 301 | +| 4-norm structure | `input_layernorm` → attention → `post_self_attn_layernorm` → MLP → `post_mlp_layernorm` | |
| 302 | +| GQA + sliding/global mix | Sliding layers use 256-dim heads; global layers use 512-dim heads | |
| 303 | +| Dual RoPE | Sliding θ=10 000; global θ=1 000 000 with partial factor 0.25 | |
| 304 | +| Shared KV | Last 18 layers reuse KV from the last non-shared layer of the same attention type | |
| 305 | +| Per-Layer Embeddings | PLE modules are attached after `GPTModel` construction and threaded through `forward()` | |
| 306 | +| Logit softcapping | `final_logit_softcapping=30.0` is applied by the Gemma4 output layer | |
| 307 | + |
| 308 | +### VL and Audio Path |
| 309 | + |
| 310 | +`Gemma4VLModel` wraps the language model with HF vision/audio modules: |
| 311 | + |
| 312 | +- Vision tower and projector weights are mapped under `vision_tower.*` and |
| 313 | + `embed_vision.*`. |
| 314 | +- Audio tower and projector weights are mapped under `audio_tower.*` and |
| 315 | + `embed_audio.*`. |
| 316 | +- Multimodal token positions are replaced with pad token IDs before PLE lookup, |
| 317 | + matching Hugging Face behavior. |
0 commit comments