
Commit 586e692

Merge pull request #3725 from AI-Hypercomputer:gagik-gemma4-readme
PiperOrigin-RevId: 904196908
2 parents f67d8b1 + a3708b8 commit 586e692

3 files changed: 14 additions and 4 deletions


docs/reference/models/supported_models_and_architectures.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -34,8 +34,8 @@ MaxText is an open-source, high-performance LLM framework written in Python/JAX.
 
 ### Gemma
 
-- **Variants**: Gemma 1 (2B/7B), Gemma 2 (2B/9B/27B), **Gemma 3 (4B/12B/27B)** (text & multimodal)
-- **Notes**: RMSNorm; RoPE; GELU/SwiGLU; **QK-Norm** (Gemma 3); Local–Global interleaved attention; long-context scaling.
+- **Variants**: Gemma 1 (2B/7B), Gemma 2 (2B/9B/27B), Gemma 3 (4B/12B/27B), **Gemma 4 (31B Dense, MoE 26B-A4B)** (text & multimodal)
+- **Notes**: RMSNorm; RoPE; GELU/SwiGLU; **QK-Norm** (Gemma 3, 4); **Value Norm** (Gemma 4); Interleaved sliding-window & global attention (Gemma 3, 4); routed + shared experts (Gemma 4); long-context scaling.
 
 ### DeepSeek
 
@@ -86,7 +86,7 @@ The following summarizes observed runtime efficiency and scaling behaviors of Ma
 - **Model Implementation Guides & Source Code:**
 
 - **Llama**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/llama2/run_llama2.md) | [Llama2 and Llama3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/llama2.py) | [Llama4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/llama4.py)
-- **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma3.py)
+- **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma3.py) | [Gemma4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma4.py)
 - **Mixtral**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/mixtral/Run_Mixtral.md) | [Mixtral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/mixtral.py) | [Mistral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/mistral.py)
 - **DeepSeek**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) | [DeepSeek Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/deepseek.py)
 - **Qwen3**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md) | [Qwen3-Next Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) | [Qwen3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/qwen3.py) | [Qwen3-Next Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/qwen3.py)
```

src/maxtext/configs/models/gemma4-26b.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -45,6 +45,7 @@ num_experts: 128
 num_experts_per_tok: 8
 shared_experts: 1
 norm_topk_prob: true
+load_balance_loss_weight: 0.001
 
 # Multimodal flags (need to set use_multimodal=true)
 rope_theta_for_vit: 100
```
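The new `load_balance_loss_weight` entry scales an auxiliary router loss of the Switch-Transformer style commonly used for MoE load balancing. The sketch below is an illustration under that assumption, not MaxText's actual implementation (the function name and NumPy formulation are mine): per expert, it multiplies the fraction of routing slots the expert receives by the mean router probability assigned to it.

```python
import numpy as np

def load_balance_loss(router_logits, num_experts_per_tok):
    """Switch-style auxiliary load-balance loss (illustrative sketch only).

    router_logits: array of shape [num_tokens, num_experts].
    Returns ~1.0 for perfectly uniform routing, larger when routing is skewed.
    """
    # Softmax over the expert dimension to get routing probabilities.
    z = router_logits - router_logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    num_tokens, num_experts = probs.shape
    # Hard top-k assignment: which experts each token is routed to.
    topk = np.argsort(-probs, axis=-1)[:, :num_experts_per_tok]
    mask = np.zeros_like(probs)
    np.put_along_axis(mask, topk, 1.0, axis=-1)
    # Fraction of routing slots per expert, and its mean router probability.
    fraction_tokens = mask.mean(axis=0) / num_experts_per_tok
    fraction_probs = probs.mean(axis=0)
    return num_experts * float(np.sum(fraction_tokens * fraction_probs))
```

With this formulation, uniform routing yields a loss of about 1.0 regardless of `num_experts`, so with the config's `num_experts: 128` and `num_experts_per_tok: 8` the default weight of `0.001` contributes roughly 0.001 to the total objective.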

tests/end_to_end/tpu/gemma4/Run_Gemma4.md

Lines changed: 10 additions & 1 deletion
````diff
@@ -29,8 +29,17 @@ You can train from scratch to generate a new checkpoint. One example command to
 python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml model_name=gemma4-26b base_output_directory=${BASE_OUTPUT_DIRECTORY?} dataset_path=${DATASET_PATH?} tokenizer_path=google/gemma-4-26b-a4b-it per_device_batch_size=1 run_name=runner_pretrain_gemma4_26b steps=10 enable_checkpointing=false sharding_tolerance=0.03
 ```
 
+### Load balance loss (MoE only)
+Gemma4-26B is a Mixture-of-Experts model and uses an auxiliary load balance loss during training to encourage uniform expert utilization. The weight is controlled by `load_balance_loss_weight` and defaults to `0.001` in `src/maxtext/configs/models/gemma4-26b.yml`. To tune or disable it, override from the command line, for example:
+
+```sh
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml model_name=gemma4-26b <other flags> load_balance_loss_weight=0.01
+```
+
+Set `load_balance_loss_weight=0.0` to turn the auxiliary loss off. This flag has no effect on the dense Gemma4-31B model.
+
 ## Checkpoint Conversion
-To obtain the Gemma4 model weights, you can access them on Hugging Face (e.g., [google/gemma-4-31B-it](https://huggingface.co/google/gemma-4-31B-it)). You will need to accept the Gemma4 license through your Hugging Face account and provide your Hugging Face access token (as `HF_TOKEN`) for authentication. You can then convert them directly into a MaxText compatible format. Here's an example of converting the model weights using the conversion script (`tests/end_to_end/tpu/gemma4/26b/convert_gemma4_26b.sh`):
+To obtain the Gemma4 model weights, you can access them on Hugging Face (e.g., [google/gemma-4-31B-it](https://huggingface.co/google/gemma-4-31B-it)). You will need to accept the Gemma4 license through your Hugging Face account and provide your Hugging Face access token (as `HF_TOKEN`) for authentication. You can then convert them directly into a MaxText compatible format. Here's an example of converting the model weights using the conversion script (`tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh`):
 
 ```sh
 python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
````
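The semantics of the flag described in the new Run_Gemma4.md section can be sketched in one line (hypothetical helper, not MaxText's code): the auxiliary term is simply added to the cross-entropy loss scaled by the weight, which is why `0.0` disables it.

```python
# Illustration only (hypothetical function, not MaxText's implementation):
# how a weight like load_balance_loss_weight folds the auxiliary MoE loss
# into the training objective.
def total_loss(ce_loss: float, aux_loss: float, load_balance_loss_weight: float) -> float:
    # With weight 0.0 the auxiliary term vanishes, matching the documented
    # way to turn the load balance loss off.
    return ce_loss + load_balance_loss_weight * aux_loss
```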
