Skip to content

Commit f0842ca

Browse files
Merge pull request #4011 from AI-Hypercomputer:multimodal_gemma3
PiperOrigin-RevId: 923520359
2 parents be4fd71 + 78f9a6c commit f0842ca

7 files changed

Lines changed: 107 additions & 76 deletions

File tree

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,21 @@ def main(
924924
max_logging.log(f"HuggingFace model loaded. dtypes: {unique_dtypes}")
925925
print_ram_usage("After full HF model load")
926926

927+
# transformers>=5.8 removed the intermediate `vision_model` attribute,
928+
# so keys are now `model.vision_tower.embeddings.*` instead
929+
# of `model.vision_tower.vision_model.embeddings.*`.
930+
# Remap to the old format so that the param_mapping continues to work.
931+
if eager_load_method == "transformers" and config.use_multimodal:
932+
old_prefix = "model.vision_tower.vision_model."
933+
new_prefix = "model.vision_tower."
934+
needs_remap = any(k.startswith(new_prefix) and not k.startswith(old_prefix) for k in hf_state_dict_numpy)
935+
if needs_remap:
936+
max_logging.log("Detected new-style key layout; remapping vision_tower keys.")
937+
hf_state_dict_numpy = {
938+
(old_prefix + k[len(new_prefix) :] if k.startswith(new_prefix) and not k.startswith(old_prefix) else k): v
939+
for k, v in hf_state_dict_numpy.items()
940+
}
941+
927942
def _eager_getter(key):
928943
if key not in hf_state_dict_numpy:
929944
raise ValueError(f"HuggingFace key {key} not found in state_dict.")

src/maxtext/checkpoint_conversion/utils/hf_shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def GEMMA3_HF_WEIGHTS_TO_SHAPE(config):
5353
vision_patch_size = vision_config["patch_size"]
5454
vision_num_channels = vision_config["num_channels"]
5555
vision_image_size = vision_config["image_size"]
56-
vision_num_positions = (vision_image_size / vision_patch_size) ** 2
56+
vision_num_positions = (vision_image_size // vision_patch_size) ** 2
5757

5858
vocab_size = text_config["vocab_size"]
5959

tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
# Usage:
1212
# export HF_TOKEN=<your Hugging Face access token>
13-
# export RUN_ID=$(date +%Y-%m-%d-%H-%M)
13+
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
1414
# bash test_gemma3_to_mt.sh $RUN_ID
1515
# bash test_gemma3.sh $RUN_ID
1616

1717

1818
set -ex
1919

20-
run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
20+
run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
2121
MODEL_NAME='gemma3-4b'
2222

2323
# To convert the multimodal model, make sure use_multimodal is set to true
Lines changed: 67 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,87 @@
11
#!/bin/bash
22

3-
# This file contains an end-to-end Airflow nightly test, designed to run once a day on a v5p-8, along with documentation to guide users in getting started with Gemma3-4B.
3+
# Validates the Gemma3-4B SFT multimodal pipeline using a pre-converted MaxText checkpoint.
44

5-
# The flow of this file is as follows:
6-
# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText
7-
# 2. Run multimodal decoding of Gemma3-4B, with the converted checkpoint.
8-
# 3. Run supervised finetuning (SFT) of Gemma3-4B on ChartQA dataset with the converted checkpoint.
9-
# 4. Run decoding from the finetuned checkpoint from step 3, seeing the short answer from SFT.
10-
# 5. Convert the SFT checkpoint back to HuggingFace format.
5+
# The flow of this script is as follows:
6+
# 1. Run inference on the pre-converted checkpoint.
7+
# 2. Run SFT of Gemma3-4B on ChartQA dataset with the converted checkpoint.
8+
# 3. Run inference on the checkpoint produced by the SFT run.
9+
# 4. Convert the checkpoint produced by the SFT run back to HuggingFace format.
10+
11+
# Usage:
12+
# export HF_TOKEN=<your Hugging Face access token>
13+
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
14+
# bash test_gemma3_to_mt.sh $RUN_ID true
15+
# bash test_gemma3_multimodal_sft.sh $RUN_ID
1116

1217
# Note: You can stop at any step if you just want to run part of the flow.
1318

1419
set -ex
15-
idx=$(date +%Y-%m-%d-%H-%M)
20+
21+
run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
1622
MODEL_NAME='gemma3-4b'
17-
export MODEL_VARIATION='4b'
18-
HF_TOKEN='' # Important!!! Save your hf access token here
19-
HF_GOLDEN_MODEL='google/gemma-3-4b-pt'
20-
TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"'/tokenizer.gemma3'
21-
# To convert the multimodal model, make sure the use_multimodal is set to be true
22-
USE_MULTIMODAL=true
23-
SCAN_LAYERS=false
24-
SFT_STEPS=10
2523

26-
# Installing torch for deps in forward_pass_logit_checker.py
27-
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
24+
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
25+
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
26+
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
27+
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items
2828

29-
# After downloading checkpoints, copy them to GCS bucket at $MODEL_BUCKET \
30-
# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
31-
export MODEL_BUCKET=gs://maxtext-gemma/unified/gemma3
29+
# Step 1: Install torch and google-jetstream
30+
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
31+
python3 -m pip install google-jetstream@https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip --no-deps
3232

33-
# 1. Convert the HuggingFace checkpoint to MaxText unscanned ckpt:
34-
python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
33+
# Step 2: Run inference on the original checkpoint converted from Hugging Face
34+
python3 -m maxtext.inference.decode \
3535
model_name=${MODEL_NAME} \
36-
hf_access_token=${HF_TOKEN} \
37-
base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \
38-
use_multimodal=${USE_MULTIMODAL} \
39-
scan_layers=${SCAN_LAYERS}
40-
41-
# 2. Decode the converted checkpoint to make sure it works
42-
export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items
43-
python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
36+
load_parameters_path=${UNSCANNED_CKPT_PATH} \
37+
per_device_batch_size=1 \
38+
run_name=${run_id} \
39+
max_prefill_predict_length=272 \
40+
max_target_length=300 \
41+
steps=1 \
42+
async_checkpointing=false \
43+
scan_layers=false \
44+
use_multimodal=True \
45+
tokenizer_type=huggingface \
46+
prompt=\'Describe\ image\ \<start_of_image\>\' \
47+
image_path=\'tests/assets/test_image.jpg\' \
48+
attention=\'dot_product\' skip_jax_distributed_system=True
4449

45-
# 3. SFT the MaxText converted checkpoint on ChartQA dataset
46-
export BASE_OUTPUT_DIRECTORY=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/sft
47-
python -m maxtext.trainers.post_train.sft.train_sft_native "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft-vision-chartqa.yml \
48-
run_name=$idx \
49-
model_name=$MODEL_NAME tokenizer_path="google/gemma-3-4b-pt" \
50+
# Step 3: Run SFT on the MaxText checkpoint on ChartQA dataset
51+
python -m maxtext.trainers.post_train.sft.train_sft_native "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/post_train/sft-vision-chartqa.yml \
52+
run_name=${run_id} \
53+
model_name=${MODEL_NAME} \
5054
per_device_batch_size=1 \
5155
max_prefill_predict_length=1024 max_target_length=2048 \
52-
steps=$SFT_STEPS \
53-
scan_layers=$SCAN_LAYERS async_checkpointing=False \
56+
steps=5 \
57+
scan_layers=false async_checkpointing=False \
5458
attention=dot_product \
55-
dataset_type=hf hf_path=parquet hf_access_token=$HF_TOKEN \
59+
dataset_type=hf hf_path=parquet \
5660
hf_train_files=gs://aireenmei-multipod/dataset/hf/chartqa/train-* \
57-
base_output_directory=$BASE_OUTPUT_DIRECTORY \
58-
load_parameters_path=$UNSCANNED_CKPT_PATH \
59-
dtype=bfloat16 weight_dtype=bfloat16 sharding_tolerance=0.05
61+
base_output_directory=${BASE_OUTPUT_DIRECTORY}/multimodal/sft \
62+
load_parameters_path=${UNSCANNED_CKPT_PATH} \
63+
dtype=bfloat16 weight_dtype=bfloat16 sharding_tolerance=0.05 \
64+
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False
6065

61-
# 4. Decode from the finetuned checkpoint from step 3
62-
export FINAL_CKPT_STEP=$((SFT_STEPS - 1))
63-
export FINETUNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${idx}/checkpoints/${FINAL_CKPT_STEP}/items
64-
python3 -m maxtext.inference.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${FINETUNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
66+
# Step 4: Run inference on the checkpoint generated from the previous run
67+
python3 -m maxtext.inference.decode \
68+
model_name=${MODEL_NAME} \
69+
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/multimodal/sft/${run_id}/checkpoints/4/items \
70+
per_device_batch_size=1 \
71+
run_name=${run_id} \
72+
max_prefill_predict_length=272 \
73+
max_target_length=300 \
74+
steps=1 \
75+
async_checkpointing=false \
76+
scan_layers=false \
77+
use_multimodal=true \
78+
prompt=\'Describe\ image\ \<start_of_image\>\' \
79+
image_path=\'tests/assets/test_image.jpg\' \
80+
attention=\'dot_product\'
6581

66-
# 5. Convert the SFT checkpoint back to HuggingFace format.
67-
export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx}
68-
export CKPT_PATH="gs://maxtext-gemma/unified/gemma3/4b/unscanned/sft/2025-08-08-18-28/2025-08-08-18-28/checkpoints/9/items"
69-
python3 -m maxtext.checkpoint_conversion.to_huggingface "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \
82+
# Step 5: Convert the SFT checkpoint back to HuggingFace format
83+
python3 -m maxtext.checkpoint_conversion.to_huggingface \
7084
model_name=${MODEL_NAME} \
71-
hf_access_token=${HF_TOKEN} \
72-
load_parameters_path=${CKPT_PATH} \
73-
base_output_directory=${LOCAL_PATH} \
74-
use_multimodal=${USE_MULTIMODAL} \
75-
scan_layers=$SCAN_LAYERS
85+
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/multimodal/sft/${run_id}/checkpoints/4/items \
86+
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
87+
use_multimodal=true scan_layers=false

tests/end_to_end/tpu/gemma3/4b/test_gemma3_rl.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
# Usage:
1212
# export HF_TOKEN=<your Hugging Face access token>
13-
# export RUN_ID=$(date +%Y-%m-%d-%H-%M)
13+
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
1414
# bash test_gemma3_to_mt.sh $RUN_ID
1515
# bash test_gemma3_rl.sh $RUN_ID
1616

1717

1818
set -ex
1919

20-
run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
20+
run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
2121
MODEL_NAME='gemma3-4b'
2222

2323
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored

tests/end_to_end/tpu/gemma3/4b/test_gemma3_sft.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
# Usage:
1212
# export HF_TOKEN=<your Hugging Face access token>
13-
# export RUN_ID=$(date +%Y-%m-%d-%H-%M)
13+
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
1414
# bash test_gemma3_to_mt.sh $RUN_ID
1515
# bash test_gemma3_sft.sh $RUN_ID
1616

1717

1818
set -ex
1919

20-
run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
20+
run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
2121
MODEL_NAME='gemma3-4b'
2222

2323
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored

tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,21 @@
77
# 2. Convert the HuggingFace checkpoint to MaxText format in both unscanned and scanned formats.
88
# 3. Run a forward pass logits check to verify the converted checkpoint matches the original HF model.
99

10-
# Pre-requisites:
11-
# 1. Set HF_TOKEN environment variable to your Hugging Face access token with read permissions
12-
# export HF_TOKEN=<Hugging Face access token>
10+
# Usage:
11+
# export HF_TOKEN=<your Hugging Face access token>
12+
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
13+
# bash test_gemma3_to_mt.sh $RUN_ID - to convert the checkpoint and run logit check for non-multimodal version
14+
# bash test_gemma3_to_mt.sh $RUN_ID true - to convert the checkpoint and run logit check for multimodal version
1315

1416

1517
set -ex
1618

17-
run_id=${1:-$(date +%Y-%m-%d-%H-%M)}
19+
run_id=${1:-$(date +%Y-%m-%d-%H-%M-%S)}
1820
MODEL_NAME='gemma3-4b'
1921
HF_GOLDEN_MODEL='google/gemma-3-4b-it'
2022

2123
# To convert the multimodal model, pass `true` as the second argument so that use_multimodal is set to true
22-
USE_MULTIMODAL=false
24+
USE_MULTIMODAL=${2:-false}
2325

2426
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you want to store scanned and unscanned checkpoints
2527
BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}/to_maxtext
@@ -58,12 +60,14 @@ echo "Scanned checkpoint path: ${SCANNED_CKPT_PATH}"
5860
# Step 3: Test whether the forward pass logits match the original HF model
5961
# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
6062
# TODO: improve forward_pass_logit_checker to test multimodal prompts
61-
python3 -m tests.utils.forward_pass_logit_checker \
62-
load_parameters_path=${UNSCANNED_CKPT_PATH} \
63-
model_name=${MODEL_NAME} \
64-
use_multimodal=${USE_MULTIMODAL} \
65-
scan_layers=false \
66-
--hf_model_path=${HF_GOLDEN_MODEL} \
67-
--max_kl_div=0.03 \
68-
--run_hf_model=true \
69-
hardware=cpu skip_jax_distributed_system=True
63+
if [ "${USE_MULTIMODAL}" = "false" ]; then
64+
python3 -m tests.utils.forward_pass_logit_checker \
65+
load_parameters_path=${UNSCANNED_CKPT_PATH} \
66+
model_name=${MODEL_NAME} \
67+
use_multimodal=${USE_MULTIMODAL} \
68+
scan_layers=false \
69+
--hf_model_path=${HF_GOLDEN_MODEL} \
70+
--max_kl_div=0.03 \
71+
--run_hf_model=true \
72+
hardware=cpu skip_jax_distributed_system=True
73+
fi

0 commit comments

Comments
 (0)