Skip to content

Commit a1fb834

Browse files
Merge pull request #3939 from AI-Hypercomputer:gemma3_4b_refactor
PiperOrigin-RevId: 918521707
2 parents 154e9b7 + 95f0c9b commit a1fb834

5 files changed

Lines changed: 246 additions & 158 deletions

File tree

Lines changed: 82 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,94 @@
11
#!/bin/bash
22

3-
# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma3-4b.
3+
# Validates the Gemma3-4B pre-training pipeline using a pre-converted MaxText checkpoint.
44

5-
# The flow of this file is as follows:
6-
# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText
7-
# 2. Run decoding, finetuning of Gemma3-4b with the converted checkpoint. Also, run pretraining of Gemma3-4b
8-
# 3. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
5+
# The flow of this script is as follows:
6+
# 1. Run inference on the pre-converted checkpoint.
7+
# 2. Run pre-training starting from the pre-converted checkpoint.
8+
# 3. Run inference on the checkpoint produced by the pre-training run.
9+
# 4. Convert the checkpoint produced by the pre-training run back to HuggingFace format.
10+
11+
# Usage:
12+
# export HF_TOKEN=<your Hugging Face access token>
13+
# export RUN_ID=$(date +%Y-%m-%d-%H-%M)
14+
# bash test_gemma3_to_mt.sh $RUN_ID
15+
# bash test_gemma3.sh $RUN_ID
916

1017

1118
set -ex
12-
idx=$(date +%Y-%m-%d-%H-%M)
13-
export MODEL_VARIATION='4b'
14-
export MODEL_NAME=gemma3-${MODEL_VARIATION}
1519

16-
# Installing torch for deps in forward_pass_logit_checker
17-
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
20+
run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
21+
MODEL_NAME='gemma3-4b'
22+
23+
# To convert the multimodal model, make sure the use_multimodal is set to be true
24+
USE_MULTIMODAL=false
25+
26+
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
27+
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
28+
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
1829

19-
# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
20-
# Non-Googlers please remember to use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
21-
# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
22-
# You can use the Flax checkpoint available on Kaggle:
23-
# https://www.kaggle.com/models/google/gemma-3/flax/
30+
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
31+
DATASET_PATH=gs://maxtext-dataset
32+
33+
# Step 1: Install torch
34+
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
2435

25-
export CHKPT_BUCKET=gs://maxtext-gemma/gemma3/flax
26-
export MODEL_BUCKET=gs://maxtext-gemma/gemma3
36+
# Step 2: Run inference on the original checkpoint converted from Hugging Face
37+
if [ ${USE_MULTIMODAL} == true ]; then
38+
python3 -m maxtext.inference.decode \
39+
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
40+
load_parameters_path=${UNSCANNED_CKPT_PATH} \
41+
per_device_batch_size=1 run_name=${run_id} \
42+
max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \
43+
scan_layers=false use_multimodal=true \
44+
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
45+
prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
46+
else
47+
python3 -m maxtext.inference.decode \
48+
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
49+
load_parameters_path=${UNSCANNED_CKPT_PATH} \
50+
per_device_batch_size=1 run_name=${run_id} \
51+
max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \
52+
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
53+
scan_layers=false prompt='I love to' attention=\'dot_product\'
54+
fi
2755

28-
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gemma3_chkpt --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION}
56+
# Step 3: Run Pre-training on the converted checkpoint
57+
# We can also run training by using the scanned converted checkpoint
58+
# Note that scanned checkpoint helps with efficient training
59+
python3 -m maxtext.trainers.pre_train.train \
60+
base_output_directory=${BASE_OUTPUT_DIRECTORY}/train \
61+
dataset_path=${DATASET_PATH} tokenizer_type="huggingface" \
62+
load_parameters_path=${UNSCANNED_CKPT_PATH} \
63+
per_device_batch_size=1 run_name=${run_id} \
64+
max_target_length=8192 steps=5 async_checkpointing=false \
65+
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
66+
model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL}
2967

68+
# Step 4: Run inference on the checkpoint generated from the previous run
69+
if [ ${USE_MULTIMODAL} == true ]; then
70+
python3 -m maxtext.inference.decode \
71+
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
72+
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
73+
per_device_batch_size=1 run_name=${run_id} \
74+
max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \
75+
scan_layers=false use_multimodal=true \
76+
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
77+
prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
78+
else
79+
python3 -m maxtext.inference.decode \
80+
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
81+
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
82+
per_device_batch_size=1 run_name=${run_id} \
83+
max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \
84+
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
85+
scan_layers=false prompt='I love to' attention=\'dot_product\'
86+
fi
3087

31-
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
32-
export DATASET_PATH=gs://maxtext-dataset
33-
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run
34-
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemma3-4b
35-
# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train` and `decode` commands
36-
export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
37-
export RUN_NAME=unscanned_chkpt_${idx}
38-
# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
39-
# We can do this by running `maxtext.utils.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
40-
JAX_PLATFORMS=cpu python3 -m maxtext.utils.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true
41-
42-
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items
43-
44-
# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
45-
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
46-
python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
47-
48-
# # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
49-
python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0
50-
51-
# Finetune by using the scanned converted checkpoint by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint.
52-
# export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-4b/2025-03-18-19-03/scanned/0/items
53-
FINETUNE_RUN_NAME=runner_finetune_${idx}
54-
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03
55-
56-
# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load_parameters_path
57-
PRETRAIN_RUN_NAME=runner_pretrain_${idx}
58-
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03
88+
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
89+
python3 -m maxtext.checkpoint_conversion.to_huggingface \
90+
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
91+
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
92+
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
93+
use_multimodal=${USE_MULTIMODAL} \
94+
scan_layers=false
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/bin/bash
2+
3+
# Validates the Gemma3-4B RL pipeline using a pre-converted MaxText checkpoint.
4+
5+
# The flow of this script is as follows:
6+
# 1. Run inference on the pre-converted checkpoint.
7+
# 2. Run RL starting from the pre-converted checkpoint.
8+
# 3. Run inference on the checkpoint produced by the RL run.
9+
# 4. Convert the checkpoint produced by the RL run back to HuggingFace format.
10+
11+
# Usage:
12+
# export HF_TOKEN=<your Hugging Face access token>
13+
# export RUN_ID=$(date +%Y-%m-%d-%H-%M)
14+
# bash test_gemma3_to_mt.sh $RUN_ID
15+
# bash test_gemma3_rl.sh $RUN_ID
16+
17+
18+
set -ex
19+
20+
run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
21+
MODEL_NAME='gemma3-4b'
22+
23+
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
24+
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
25+
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
26+
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items
27+
28+
# Step 1: Install torch
29+
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
30+
31+
# Step 2: Run inference on the original checkpoint converted from Hugging Face
32+
python3 -m maxtext.inference.vllm_decode \
33+
model_name=${MODEL_NAME} \
34+
load_parameters_path=${UNSCANNED_CKPT_PATH} \
35+
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
36+
hbm_utilization_vllm=0.5 \
37+
prompt='Suggest some famous landmarks in London.' \
38+
use_chat_template=True scan_layers=false
39+
40+
# Step 3: Run RL on the converted checkpoint
41+
python3 -m maxtext.trainers.post_train.rl.train_rl \
42+
base_output_directory=${BASE_OUTPUT_DIRECTORY}/rl \
43+
load_parameters_path=${SCANNED_CKPT_PATH} \
44+
run_name=${run_id} rl.loss_algo='grpo' scan_layers=true \
45+
num_batches=5 batch_size=1 num_test_batches=5 \
46+
model_name=${MODEL_NAME} enable_single_controller=True \
47+
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
48+
rollout_tensor_parallelism=1 \
49+
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
50+
vllm_additional_config='{"maxtext_config": {"model_name": "gemma3-4b", "log_config": "false"}}'
51+
52+
53+
# Step 4: Run inference on the checkpoint generated from the previous run
54+
python3 -m maxtext.inference.vllm_decode \
55+
model_name=${MODEL_NAME} \
56+
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \
57+
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
58+
hbm_utilization_vllm=0.5 \
59+
prompt='Suggest some famous landmarks in London.' \
60+
use_chat_template=True scan_layers=true
61+
62+
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
63+
python3 -m maxtext.checkpoint_conversion.to_huggingface \
64+
model_name=${MODEL_NAME} \
65+
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \
66+
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
67+
use_multimodal=false scan_layers=true
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/bin/bash
2+
3+
# Validates the Gemma3-4B SFT pipeline using a pre-converted MaxText checkpoint.
4+
5+
# The flow of this script is as follows:
6+
# 1. Run inference on the pre-converted checkpoint.
7+
# 2. Run SFT starting from the pre-converted checkpoint.
8+
# 3. Run inference on the checkpoint produced by the SFT run.
9+
# 4. Convert the checkpoint produced by the SFT run back to HuggingFace format.
10+
11+
# Usage:
12+
# export HF_TOKEN=<your Hugging Face access token>
13+
# export RUN_ID=$(date +%Y-%m-%d-%H-%M)
14+
# bash test_gemma3_to_mt.sh $RUN_ID
15+
# bash test_gemma3_sft.sh $RUN_ID
16+
17+
18+
set -ex
19+
20+
run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
21+
MODEL_NAME='gemma3-4b'
22+
23+
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
24+
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
25+
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
26+
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items
27+
28+
# Step 1: Install torch
29+
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
30+
31+
# Step 2: Run inference on the original checkpoint converted from Hugging Face
32+
python3 -m maxtext.inference.vllm_decode \
33+
model_name=${MODEL_NAME} \
34+
load_parameters_path=${UNSCANNED_CKPT_PATH} \
35+
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
36+
hbm_utilization_vllm=0.5 \
37+
prompt="Suggest some famous landmarks in London." \
38+
use_chat_template=True scan_layers=false
39+
40+
# Step 3: Run SFT on the converted checkpoint
41+
python3 -m maxtext.trainers.post_train.sft.train_sft \
42+
base_output_directory=${BASE_OUTPUT_DIRECTORY}/sft \
43+
load_parameters_path=${SCANNED_CKPT_PATH} \
44+
per_device_batch_size=1 run_name=${run_id} \
45+
steps=5 scan_layers=true \
46+
model_name=${MODEL_NAME} enable_single_controller=True \
47+
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False
48+
49+
# Step 4: Run inference on the checkpoint generated from the previous run
50+
python3 -m maxtext.inference.vllm_decode \
51+
model_name=${MODEL_NAME} \
52+
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \
53+
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
54+
hbm_utilization_vllm=0.5 \
55+
prompt="Suggest some famous landmarks in London." \
56+
use_chat_template=True scan_layers=true
57+
58+
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
59+
python3 -m maxtext.checkpoint_conversion.to_huggingface \
60+
model_name=${MODEL_NAME} \
61+
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \
62+
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
63+
use_multimodal=false scan_layers=true

tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)