|
1 | 1 | #!/bin/bash |
2 | 2 |
|
3 | | -# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma3-4b. |
| 3 | +# Validates the Gemma3-4B pre-training pipeline using a pre-converted MaxText checkpoint. |
4 | 4 |
|
5 | | -# The flow of this file is as follows: |
6 | | -# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText |
7 | | -# 2. Run decoding, finetuning of Gemma3-4b with the converted checkpoint. Also, run pretraining of Gemma3-4b |
8 | | -# 3. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. |
| 5 | +# The flow of this script is as follows: |
| 6 | +# 1. Run inference on the pre-converted checkpoint. |
| 7 | +# 2. Run pre-training starting from the pre-converted checkpoint. |
| 8 | +# 3. Run inference on the checkpoint produced by the pre-training run. |
| 9 | +# 4. Convert the checkpoint produced by the pre-training run back to HuggingFace format. |
| 10 | + |
| 11 | +# Usage: |
| 12 | +# export HF_TOKEN=<your Hugging Face access token> |
| 13 | +# export RUN_ID=$(date +%Y-%m-%d-%H-%M) |
| 14 | +# bash test_gemma3_to_mt.sh $RUN_ID |
| 15 | +# bash test_gemma3.sh $RUN_ID |
9 | 16 |
|
10 | 17 |
|
11 | 18 | set -ex |
12 | | -idx=$(date +%Y-%m-%d-%H-%M) |
13 | | -export MODEL_VARIATION='4b' |
14 | | -export MODEL_NAME=gemma3-${MODEL_VARIATION} |
15 | 19 |
|
16 | | -# Installing torch for deps in forward_pass_logit_checker |
17 | | -python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu |
| 20 | +run_id=${1:-$(date +%Y-%m-%d-%H-%M)} |
| 21 | +MODEL_NAME='gemma3-4b' |
| 22 | + |
| 23 | +# To run this script against the multimodal model (inference, training, and conversion), set USE_MULTIMODAL below to true
| 24 | +USE_MULTIMODAL=false |
| 25 | + |
| 26 | +# Non-Googlers: please point `BASE_OUTPUT_DIRECTORY` to a GCS path that you own. This script reads the unscanned checkpoint from `to_maxtext/unscanned/<run_id>` under it and writes the `train` and `to_huggingface` outputs under it
| 27 | +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} |
| 28 | +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items |
18 | 29 |
|
19 | | -# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ |
20 | | -# Non-Googlers please remember to use separate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). |
21 | | -# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing. |
22 | | -# You can use the Flax checkpoint available on Kaggle: |
23 | | -# https://www.kaggle.com/models/google/gemma-3/flax/ |
| 30 | +# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data |
| 31 | +DATASET_PATH=gs://maxtext-dataset |
| 32 | + |
| 33 | +# Step 1: Install torch |
| 34 | +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu |
24 | 35 |
|
25 | | -export CHKPT_BUCKET=gs://maxtext-gemma/gemma3/flax |
26 | | -export MODEL_BUCKET=gs://maxtext-gemma/gemma3 |
| 36 | +# Step 2: Run inference on the original checkpoint converted from Hugging Face |
| 37 | +if [ ${USE_MULTIMODAL} == true ]; then |
| 38 | + python3 -m maxtext.inference.decode \ |
| 39 | + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ |
| 40 | + load_parameters_path=${UNSCANNED_CKPT_PATH} \ |
| 41 | + per_device_batch_size=1 run_name=${run_id} \ |
| 42 | + max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ |
| 43 | + scan_layers=false use_multimodal=true \ |
| 44 | + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ |
| 45 | + prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' |
| 46 | +else |
| 47 | + python3 -m maxtext.inference.decode \ |
| 48 | + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ |
| 49 | + load_parameters_path=${UNSCANNED_CKPT_PATH} \ |
| 50 | + per_device_batch_size=1 run_name=${run_id} \ |
| 51 | + max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ |
| 52 | + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ |
| 53 | + scan_layers=false prompt='I love to' attention=\'dot_product\' |
| 54 | +fi |
27 | 55 |
|
28 | | -python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gemma3_chkpt --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION} |
| 56 | +# Step 3: Run Pre-training on the converted checkpoint |
| 57 | +# This run loads the unscanned checkpoint (scan_layers=false). Training can also be run from the scanned converted checkpoint;
| 58 | +# note that a scanned checkpoint helps with efficient training.
| 59 | +python3 -m maxtext.trainers.pre_train.train \ |
| 60 | + base_output_directory=${BASE_OUTPUT_DIRECTORY}/train \ |
| 61 | + dataset_path=${DATASET_PATH} tokenizer_type="huggingface" \ |
| 62 | + load_parameters_path=${UNSCANNED_CKPT_PATH} \ |
| 63 | + per_device_batch_size=1 run_name=${run_id} \ |
| 64 | + max_target_length=8192 steps=5 async_checkpointing=false \ |
| 65 | + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ |
| 66 | + model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL} |
29 | 67 |
|
| 68 | +# Step 4: Run inference on the checkpoint generated from the previous run |
| 69 | +if [ ${USE_MULTIMODAL} == true ]; then |
| 70 | + python3 -m maxtext.inference.decode \ |
| 71 | + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ |
| 72 | + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ |
| 73 | + per_device_batch_size=1 run_name=${run_id} \ |
| 74 | + max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false \ |
| 75 | + scan_layers=false use_multimodal=true \ |
| 76 | + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ |
| 77 | + prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' |
| 78 | +else |
| 79 | + python3 -m maxtext.inference.decode \ |
| 80 | + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ |
| 81 | + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ |
| 82 | + per_device_batch_size=1 run_name=${run_id} \ |
| 83 | + max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false \ |
| 84 | + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \ |
| 85 | + scan_layers=false prompt='I love to' attention=\'dot_product\' |
| 86 | +fi |
30 | 87 |
|
31 | | -# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data |
32 | | -export DATASET_PATH=gs://maxtext-dataset |
33 | | -# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run |
34 | | -export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemma3-4b |
35 | | -# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train` and `decode` commands |
36 | | -export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items |
37 | | -export RUN_NAME=unscanned_chkpt_${idx} |
38 | | -# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. |
39 | | -# We can do this by running `maxtext.utils.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. |
40 | | -JAX_PLATFORMS=cpu python3 -m maxtext.utils.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true |
41 | | - |
42 | | -export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items |
43 | | - |
44 | | -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. |
45 | | -# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` |
46 | | -python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" |
47 | | - |
48 | | -# # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` |
49 | | -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 |
50 | | - |
51 | | -# Finetune by using the scanned converted checkpoint by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint. |
52 | | -# export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-4b/2025-03-18-19-03/scanned/0/items |
53 | | -FINETUNE_RUN_NAME=runner_finetune_${idx} |
54 | | -python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03 |
55 | | - |
56 | | -# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load_parameters_path |
57 | | -PRETRAIN_RUN_NAME=runner_pretrain_${idx} |
58 | | -python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03 |
| 88 | +# Step 5: Convert the checkpoint from MaxText format to Hugging Face format |
| 89 | +python3 -m maxtext.checkpoint_conversion.to_huggingface \ |
| 90 | + model_name=${MODEL_NAME} tokenizer_type="huggingface" \ |
| 91 | + load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \ |
| 92 | + base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \ |
| 93 | + use_multimodal=${USE_MULTIMODAL} \ |
| 94 | + scan_layers=false |
0 commit comments