11#! /bin/bash
22
3- # This file contains an end-to-end Airflow nightly test, designed to run once a day on a v5p-8, along with documentation to guide users in getting started with Gemma3-4B .
3+ # Validates the Gemma3-4B SFT multimodal pipeline using a pre-converted MaxText checkpoint .
44
5- # The flow of this file is as follows:
6- # 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText
7- # 2. Run multimodal decoding of Gemma3-4B, with the converted checkpoint.
8- # 3. Run supervised finetuning (SFT) of Gemma3-4B on ChartQA dataset with the converted checkpoint.
9- # 4. Run decoding from the finetuned checkpoint from step 3, seeing the short answer from SFT.
10- # 5. Convert the SFT checkpoint back to HuggingFace format.
5+ # The flow of this script is as follows:
6+ # 1. Run inference on the pre-converted checkpoint.
7+ # 2. Run SFT of Gemma3-4B on ChartQA dataset with the converted checkpoint.
8+ # 3. Run inference on the checkpoint produced by the SFT run.
9+ # 4. Convert the checkpoint produced by the SFT run back to HuggingFace format.
10+
11+ # Usage:
12+ # export HF_TOKEN=<your Hugging Face access token>
13+ # export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
14+ # bash test_gemma3_to_mt.sh $RUN_ID true
15+ # bash test_gemma3_multimodal_sft.sh $RUN_ID
1116
1217# Note: You can stop at any step if you just want to run part of the flow.
1318
1419set -ex
15- idx=$( date +%Y-%m-%d-%H-%M)
20+
21+ run_id=${1:- $(date +% Y-% m-% d-% H-% M-% S)}
1622MODEL_NAME=' gemma3-4b'
17- export MODEL_VARIATION=' 4b'
18- HF_TOKEN=' ' # Important!!! Save your hf access token here
19- HF_GOLDEN_MODEL=' google/gemma-3-4b-pt'
20- TOKENIZER_PATH=" ${MAXTEXT_ASSETS_ROOT:- ${MAXTEXT_PKG_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ assets/ tokenizers} } " ' /tokenizer.gemma3'
21- # To convert the multimodal model, make sure the use_multimodal is set to be true
22- USE_MULTIMODAL=true
23- SCAN_LAYERS=false
24- SFT_STEPS=10
2523
26- # Installing torch for deps in forward_pass_logit_checker.py
27- python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
24+ # Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to the GCS paths where you have the scanned and unscanned checkpoints stored
25+ BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
26+ UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY} /to_maxtext/unscanned/${run_id} /0/items
27+ SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY} /to_maxtext/scanned/${run_id} /0/items
2828
29- # After downloading checkpoints, copy them to GCS bucket at $MODEL_BUCKET \
30- # Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
31- export MODEL_BUCKET=gs ://maxtext-gemma/unified/gemma3
29+ # Step 1: Install torch and google-jetstream
30+ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
31+ python3 -m pip install google-jetstream@https ://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip --no-deps
3232
33- # 1. Convert the HuggingFace checkpoint to MaxText unscanned ckpt:
34- python3 -m maxtext.checkpoint_conversion.to_maxtext " ${MAXTEXT_CONFIGS_DIR :- ${MAXTEXT_REPO_ROOT :- $PWD } / src / maxtext / configs} " //base.yml \
33+ # Step 2: Run inference on the original checkpoint converted from Hugging Face
34+ python3 -m maxtext.inference.decode \
3535 model_name=${MODEL_NAME} \
36- hf_access_token=${HF_TOKEN} \
37- base_output_directory=${MODEL_BUCKET} /${MODEL_VARIATION} /unscanned/${idx} \
38- use_multimodal=${USE_MULTIMODAL} \
39- scan_layers=${SCAN_LAYERS}
40-
41- # 2. Decode the converted checkpoint to make sure it works
42- export UNSCANNED_CKPT_PATH=${MODEL_BUCKET} /${MODEL_VARIATION} /unscanned/${idx} /0/items
43- python3 -m maxtext.inference.decode " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " //base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\' Describe\ image\ \< start_of_image\>\' image_path=\' tests/assets/test_image.jpg\' attention=\' dot_product\'
36+ load_parameters_path=${UNSCANNED_CKPT_PATH} \
37+ per_device_batch_size=1 \
38+ run_name=${run_id} \
39+ max_prefill_predict_length=272 \
40+ max_target_length=300 \
41+ steps=1 \
42+ async_checkpointing=false \
43+ scan_layers=false \
44+ use_multimodal=True \
45+ tokenizer_type=huggingface \
46+ prompt=\' Describe\ image\ \< start_of_image\>\' \
47+ image_path=\' tests/assets/test_image.jpg\' \
48+ attention=\' dot_product\' skip_jax_distributed_system=True
4449
45- # 3. SFT the MaxText converted checkpoint on ChartQA dataset
46- export BASE_OUTPUT_DIRECTORY=${MODEL_BUCKET} /${MODEL_VARIATION} /unscanned/sft
47- python -m maxtext.trainers.post_train.sft.train_sft_native " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " //sft-vision-chartqa.yml \
48- run_name=$idx \
49- model_name=$MODEL_NAME tokenizer_path=" google/gemma-3-4b-pt" \
50+ # Step 3: Run SFT on the MaxText checkpoint on ChartQA dataset
51+ python -m maxtext.trainers.post_train.sft.train_sft_native " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " /post_train/sft-vision-chartqa.yml \
52+ run_name=${run_id} \
53+ model_name=${MODEL_NAME} \
5054 per_device_batch_size=1 \
5155 max_prefill_predict_length=1024 max_target_length=2048 \
52- steps=$SFT_STEPS \
53- scan_layers=$SCAN_LAYERS async_checkpointing=False \
56+ steps=5 \
57+ scan_layers=false async_checkpointing=False \
5458 attention=dot_product \
55- dataset_type=hf hf_path=parquet hf_access_token= $HF_TOKEN \
59+ dataset_type=hf hf_path=parquet \
5660 hf_train_files=gs://aireenmei-multipod/dataset/hf/chartqa/train-* \
57- base_output_directory=$BASE_OUTPUT_DIRECTORY \
58- load_parameters_path=$UNSCANNED_CKPT_PATH \
59- dtype=bfloat16 weight_dtype=bfloat16 sharding_tolerance=0.05
61+ base_output_directory=${BASE_OUTPUT_DIRECTORY} /multimodal/sft \
62+ load_parameters_path=${UNSCANNED_CKPT_PATH} \
63+ dtype=bfloat16 weight_dtype=bfloat16 sharding_tolerance=0.05 \
64+ checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False
6065
61- # 4. Decode from the finetuned checkpoint from step 3
62- export FINAL_CKPT_STEP=$(( SFT_STEPS - 1 ))
63- export FINETUNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY} /${idx} /checkpoints/${FINAL_CKPT_STEP} /items
64- python3 -m maxtext.inference.decode " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " //base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${FINETUNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\' Describe\ image\ \< start_of_image\>\' image_path=\' tests/assets/test_image.jpg\' attention=\' dot_product\'
66+ # Step 4: Run inference on the checkpoint generated from the previous run
67+ python3 -m maxtext.inference.decode \
68+ model_name=${MODEL_NAME} \
69+ load_parameters_path=${BASE_OUTPUT_DIRECTORY} /multimodal/sft/${run_id} /checkpoints/4/items \
70+ per_device_batch_size=1 \
71+ run_name=${run_id} } \
72+ max_prefill_predict_length=272 \
73+ max_target_length=300 \
74+ steps=1 \
75+ async_checkpointing=false \
76+ scan_layers=false \
77+ use_multimodal=true \
78+ prompt=\' Describe\ image\ \< start_of_image\>\' \
79+ image_path=\' tests/assets/test_image.jpg\' \
80+ attention=\' dot_product\'
6581
66- # 5. Convert the SFT checkpoint back to HuggingFace format.
67- export LOCAL_PATH=./tmp/hf/${MODEL_NAME} /${idx}
68- export CKPT_PATH=" gs://maxtext-gemma/unified/gemma3/4b/unscanned/sft/2025-08-08-18-28/2025-08-08-18-28/checkpoints/9/items"
69- python3 -m maxtext.checkpoint_conversion.to_huggingface " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " //base.yml \
82+ # Step 5: Convert the SFT checkpoint back to HuggingFace format
83+ python3 -m maxtext.checkpoint_conversion.to_huggingface \
7084 model_name=${MODEL_NAME} \
71- hf_access_token=${HF_TOKEN} \
72- load_parameters_path=${CKPT_PATH} \
73- base_output_directory=${LOCAL_PATH} \
74- use_multimodal=${USE_MULTIMODAL} \
75- scan_layers=$SCAN_LAYERS
85+ load_parameters_path=${BASE_OUTPUT_DIRECTORY} /multimodal/sft/${run_id} /checkpoints/4/items \
86+ base_output_directory=${BASE_OUTPUT_DIRECTORY} /to_huggingface/unscanned/${run_id} \
87+ use_multimodal=true scan_layers=false
0 commit comments